From dfc52a7f3bf5e0e198e4f1f613a60e666a8604b7 Mon Sep 17 00:00:00 2001 From: ByteColtX Date: Fri, 8 May 2026 22:25:53 +0800 Subject: [PATCH] =?UTF-8?q?fix(session):=20=E6=8E=A7=E5=88=B6=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E6=96=AD=E5=BC=80=E5=90=8E=E8=87=AA=E5=8A=A8=E5=81=9C?= =?UTF-8?q?=E6=92=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 RTSP keepalive 失败回传到 PlaybackSession,并让 CLI 与 tray 在主循环中轮询后自动停播并回传错误。 --- src/cli/commands.rs | 58 +++++++++++++-- src/rtsp/client.rs | 29 +++++--- src/session/raop.rs | 23 +++--- src/session/stream.rs | 153 ++++++++++++++++++++++++++++++++++++++- src/session/transport.rs | 13 ++++ src/ui/tray/mod.rs | 32 ++++++++ src/ui/tray/windows.rs | 44 ++++++++--- 7 files changed, 309 insertions(+), 43 deletions(-) diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 879112c..e3d0dda 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -1,4 +1,6 @@ use std::io::{self, Write}; +use std::sync::mpsc::{self, RecvTimeoutError}; +use std::time::Duration; use crate::app::AppFacade; use crate::discovery::MdnsDiscoveryService; @@ -11,6 +13,13 @@ use super::output::{ }; use super::parse::{CliCommand, CliOptions}; +const PLAYBACK_POLL_INTERVAL: Duration = Duration::from_millis(250); + +enum CaptureExit { + UserRequestedStop, + PlaybackFailed(RairstreamError), +} + pub fn run_cli(cli: CliOptions) -> Result<(), RairstreamError> { let mut facade = AppFacade::new(MdnsDiscoveryService::default())?; @@ -55,10 +64,16 @@ pub fn run_cli(cli: CliOptions) -> Result<(), RairstreamError> { CliCommand::PlayCapture { selectors } => { let session = facade.play_capture(&selectors)?; print_play_capture_started(&selectors); - wait_for_ctrl_c()?; - facade.stop_capture(session)?; - print_play_capture_stopped(&selectors); - Ok(()) + match wait_for_ctrl_c_or_capture_end(&session)? { + CaptureExit::UserRequestedStop => { + facade.stop_capture(session)?; + print_play_capture_stopped(&selectors); + Ok(()) + } + CaptureExit::PlaybackFailed(error) => { + stop_capture_after_failure(&mut facade, session, error) + } + } } } } @@ -77,12 +92,39 @@ fn prompt_pairing_pin(receiver_name: &str) -> Result { Ok(pin.to_string()) } -fn wait_for_ctrl_c() -> Result<(), RairstreamError> { - let (sender, receiver) = std::sync::mpsc::channel(); +fn wait_for_ctrl_c_or_capture_end( + session: &crate::session::PlaybackSession, +) -> Result { + let (sender, receiver) = mpsc::channel(); ctrlc::set_handler(move || { let _ = sender.send(()); }) .map_err(std::io::Error::other)?; - receiver.recv().map_err(std::io::Error::other)?; - Ok(()) + + loop { + match receiver.recv_timeout(PLAYBACK_POLL_INTERVAL) { + Ok(()) => return Ok(CaptureExit::UserRequestedStop), + Err(RecvTimeoutError::Timeout) => { + if let Some(error) = session.transport_error() { + return Ok(CaptureExit::PlaybackFailed(error)); + } + } + Err(RecvTimeoutError::Disconnected) => { + return Err(std::io::Error::other("control-c listener disconnected").into()); + } + } + } +} + +fn stop_capture_after_failure( + facade: &mut AppFacade, + session: crate::session::PlaybackSession, + error: RairstreamError, +) -> Result<(), RairstreamError> { + match facade.stop_capture(session) { + Ok(()) => Err(error), + Err(stop_error) => Err(RairstreamError::Playback { + message: format!("{error}; cleanup failed: {stop_error}"), + }), + } } diff --git a/src/rtsp/client.rs b/src/rtsp/client.rs index 6634060..c5a42b7 100644 --- a/src/rtsp/client.rs +++ b/src/rtsp/client.rs @@ -1,10 +1,7 @@ use std::io::{BufRead, BufReader, Read, Write}; use std::net::TcpStream; use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender}; -use std::sync::{ - Arc, - atomic::{AtomicBool, Ordering}, -}; +use std::sync::{Arc, Mutex}; use std::thread::{self, JoinHandle}; use std::time::Duration; @@ -135,7 +132,7 @@ struct RtspKeepaliveWorker { cseq: u32, interval: Duration, command_rx: Receiver, - transport_terminated: Arc, + transport_error: Arc>>, } impl RtspKeepalive { @@ -145,7 +142,7 @@ impl RtspKeepalive { session_id: String, initial_cseq: u32, interval: Duration, - transport_terminated: Arc, + transport_error: Arc>>, ) -> Result { let endpoint = descriptor.device.endpoint(); let (command_tx, command_rx) = mpsc::channel(); @@ -159,7 +156,7 @@ impl RtspKeepalive { cseq: initial_cseq, interval, command_rx, - transport_terminated, + transport_error, } .run(); }) @@ -214,6 +211,16 @@ impl RtspKeepalive { } impl RtspKeepaliveWorker { + fn record_transport_error(&self, error: AirPlayError) { + let mut state = self + .transport_error + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if state.is_none() { + *state = Some(error); + } + } + fn run(mut self) { let endpoint = self.descriptor.device.endpoint(); @@ -249,7 +256,7 @@ impl RtspKeepaliveWorker { error = %error, "RTSP keepalive received failure response" ); - self.transport_terminated.store(true, Ordering::SeqCst); + self.record_transport_error(error); break; } trace!( @@ -266,7 +273,7 @@ impl RtspKeepaliveWorker { error = %error, "RTSP keepalive failed" ); - self.transport_terminated.store(true, Ordering::SeqCst); + self.record_transport_error(error); break; } } @@ -329,7 +336,7 @@ pub(crate) fn map_connection_error(error: std::io::Error) -> AirPlayError { mod tests { use std::io::{BufRead, BufReader, Read, Write}; use std::net::{TcpListener, TcpStream}; - use std::sync::{Arc, atomic::AtomicBool, mpsc}; + use std::sync::{Arc, Mutex, mpsc}; use std::thread; use std::time::Duration; @@ -419,7 +426,7 @@ mod tests { String::from("deadbeef"), 0, Duration::from_secs(1), - Arc::new(AtomicBool::new(false)), + Arc::new(Mutex::new(None)), ) .unwrap(); diff --git a/src/session/raop.rs b/src/session/raop.rs index 828981a..d633b20 100644 --- a/src/session/raop.rs +++ b/src/session/raop.rs @@ -1,10 +1,7 @@ //! classic `RAOP` 会话、握手与连接生命周期实现。 use std::net::{SocketAddr, ToSocketAddrs, UdpSocket}; -use std::sync::{ - Arc, - atomic::{AtomicBool, Ordering}, -}; +use std::sync::{Arc, Mutex}; use std::time::{SystemTime, UNIX_EPOCH}; use super::{AirPlayError, SessionDescriptor}; @@ -39,7 +36,7 @@ pub struct RaopConnection { control_target: SocketAddr, timing_responder: Option, rtsp_keepalive: Option, - transport_terminated: Arc, + transport_error: Arc>>, } impl RaopConnection { @@ -76,7 +73,15 @@ impl RaopConnection { #[must_use] pub fn is_terminated(&self) -> bool { - self.transport_terminated.load(Ordering::SeqCst) + self.transport_error().is_some() + } + + #[must_use] + pub fn transport_error(&self) -> Option { + self.transport_error + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() } fn stop_timing_responder(&mut self) { @@ -281,14 +286,14 @@ impl RaopSession { self.apply_record_response(&record_response)?; let keepalive_interval = compute_rtsp_keepalive_interval(setup_reply.session_timeout_secs); - let transport_terminated = Arc::new(AtomicBool::new(false)); + let transport_error = Arc::new(Mutex::new(None)); let rtsp_keepalive = SharedRtspKeepalive::start( rtsp_client, self.descriptor.clone(), setup_reply.session_id.clone(), self.cseq, keepalive_interval, - Arc::clone(&transport_terminated), + Arc::clone(&transport_error), )?; let audio_target = resolve_socket_addr(&self.descriptor.device.host, setup_reply.server_port)?; @@ -306,7 +311,7 @@ impl RaopSession { control_target, timing_responder: Some(timing_responder), rtsp_keepalive: Some(rtsp_keepalive), - transport_terminated, + transport_error, }) } diff --git a/src/session/stream.rs b/src/session/stream.rs index d47ce5a..e43341b 100644 --- a/src/session/stream.rs +++ b/src/session/stream.rs @@ -18,6 +18,14 @@ pub struct PlaybackSession { } impl PlaybackSession { + #[must_use] + pub fn transport_error(&self) -> Option { + self.connections + .iter() + .find_map(|connection| connection.connection.transport_error()) + .map(Into::into) + } + pub fn stop(mut self) -> Result<(), RairstreamError> { let capture_result = match self.capture.take() { Some(capture) => capture.stop().map_err(Into::into), @@ -200,13 +208,24 @@ fn teardown_connections(connections: Vec) -> Result<(), Rairs #[cfg(test)] mod tests { - use std::collections::VecDeque; - use std::time::Instant; + use std::collections::{HashMap, VecDeque}; + use std::io::{BufRead, BufReader, Read, Write}; + use std::net::{TcpListener, TcpStream}; + use std::sync::mpsc; + use std::thread; + use std::time::{Duration, Instant}; use crate::audio::{AudioCaptureError, AudioChunk, AudioFormat, AudioSink}; use crate::error::RairstreamError; + use crate::pairing::ReceiverCredentials; + use crate::receiver::{ + AirPlayGeneration, AuthMethod, DeviceSupport, Receiver, ReceiverCapabilities, ReceiverKind, + }; + use crate::session::AirPlayError; - use super::{chunk_duration, combine_playback_results, stream_chunks}; + use super::{ + PlaybackSession, chunk_duration, combine_playback_results, connect_receivers, stream_chunks, + }; #[derive(Debug, Default)] struct RecordingSink { @@ -296,4 +315,132 @@ mod tests { assert!(matches!(error, AudioCaptureError::InvalidFormat { .. })); } + + #[test] + fn playback_session_reports_keepalive_failure() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + let port_listener = listener.try_clone().unwrap(); + let (tx, rx) = mpsc::channel(); + + thread::spawn(move || { + let (mut stream, _) = port_listener.accept().unwrap(); + let mut reader = BufReader::new(stream.try_clone().unwrap()); + + for response in [ + "RTSP/1.0 200 OK\r\nCSeq: 1\r\n\r\n", + "RTSP/1.0 200 OK\r\nCSeq: 2\r\n\r\n", + "RTSP/1.0 200 OK\r\nTransport: RTP/AVP/UDP;unicast;mode=record;server_port=5100;control_port=5101;timing_port=5102\r\nSession: deadbeef;timeout=1\r\n\r\n", + "RTSP/1.0 200 OK\r\nSession: deadbeef\r\n\r\n", + ] { + let _request = read_rtsp_message(&mut reader).unwrap(); + stream.write_all(response.as_bytes()).unwrap(); + stream.flush().unwrap(); + } + + let keepalive_request = read_rtsp_message(&mut reader).unwrap(); + assert!(keepalive_request.starts_with("OPTIONS *")); + stream + .write_all(b"RTSP/1.0 500 Server Error\r\nSession: deadbeef\r\n\r\n") + .unwrap(); + stream.flush().unwrap(); + tx.send(()).unwrap(); + }); + + let paired_receivers: HashMap = HashMap::new(); + let connections = connect_receivers( + &[build_receiver(port)], + AudioFormat::default(), + &paired_receivers, + 100, + ) + .unwrap(); + let session = PlaybackSession { + connections, + capture: None, + }; + + rx.recv_timeout(Duration::from_secs(3)).unwrap(); + let error = wait_for_transport_error(&session).unwrap(); + + assert!(matches!( + error, + RairstreamError::Session(AirPlayError::Protocol { message }) + if message == "RTSP keepalive returned failure status 500" + )); + + session.stop().unwrap(); + } + + fn build_receiver(port: u16) -> Receiver { + Receiver { + id: String::from("speaker"), + name: String::from("Speaker"), + host: String::from("127.0.0.1"), + port, + generation: AirPlayGeneration::AirPlay1, + transport_profile: ReceiverKind::ClassicRaop, + support_level: DeviceSupport::Supported, + auth_method: AuthMethod::None, + capabilities: ReceiverCapabilities::default(), + ..Receiver::default() + } + .with_compat_fields() + } + + fn read_rtsp_message(reader: &mut BufReader) -> Result { + let raw = read_rtsp_message_bytes(reader)?; + Ok(String::from_utf8_lossy(&raw).into_owned()) + } + + fn read_rtsp_message_bytes( + reader: &mut BufReader, + ) -> Result, std::io::Error> { + let mut raw = Vec::new(); + let mut content_length = 0_usize; + + loop { + let mut line = String::new(); + let bytes_read = reader.read_line(&mut line)?; + if bytes_read == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "connection closed", + )); + } + + let normalized = line.trim_end_matches(['\r', '\n']); + raw.extend_from_slice(normalized.as_bytes()); + raw.extend_from_slice(b"\r\n"); + + if normalized.is_empty() { + break; + } + + if let Some((name, value)) = normalized.split_once(':') + && name.trim().eq_ignore_ascii_case("Content-Length") + { + content_length = value.trim().parse::().unwrap(); + } + } + + if content_length > 0 { + let mut body = vec![0_u8; content_length]; + reader.read_exact(&mut body)?; + raw.extend_from_slice(&body); + } + + Ok(raw) + } + + fn wait_for_transport_error(session: &PlaybackSession) -> Option { + let start = Instant::now(); + while start.elapsed() < Duration::from_secs(2) { + if let Some(error) = session.transport_error() { + return Some(error); + } + thread::sleep(Duration::from_millis(25)); + } + None + } } diff --git a/src/session/transport.rs b/src/session/transport.rs index edf1098..0025b64 100644 --- a/src/session/transport.rs +++ b/src/session/transport.rs @@ -82,6 +82,14 @@ impl SessionConnection { } } + #[must_use] + pub fn transport_error(&self) -> Option { + match self { + Self::ClassicRaop(connection) => connection.transport_error(), + Self::ModernAirPlay(connection) => connection.transport_error(), + } + } + #[must_use] pub fn is_terminated(&self) -> bool { match self { @@ -188,6 +196,11 @@ impl ModernAirPlayConnection { self.raop_connection.stream_transport() } + #[must_use] + pub fn transport_error(&self) -> Option { + self.raop_connection.transport_error() + } + #[must_use] pub fn is_terminated(&self) -> bool { self.raop_connection.is_terminated() diff --git a/src/ui/tray/mod.rs b/src/ui/tray/mod.rs index 5fb4808..33351d5 100644 --- a/src/ui/tray/mod.rs +++ b/src/ui/tray/mod.rs @@ -162,6 +162,18 @@ where } } + pub fn poll(&mut self) -> Vec { + let Some(error) = self + .active_session + .as_ref() + .and_then(PlaybackSession::transport_error) + else { + return Vec::new(); + }; + + self.stop_active_session_with_error(error) + } + fn refresh_devices(&mut self) -> Vec { let mut events = Vec::new(); if let Err(error) = self.facade.discover() { @@ -281,6 +293,26 @@ where events } + fn stop_active_session_with_error(&mut self, error: RairstreamError) -> Vec { + let error = if let Some(session) = self.active_session.take() { + match self.facade.stop_capture(session) { + Ok(()) => error, + Err(stop_error) => RairstreamError::Playback { + message: format!("{error}; cleanup failed: {stop_error}"), + }, + } + } else { + error + }; + + let should_exit = self.runtime.stop_streaming(); + let mut events = vec![TrayEvent::Error(error.to_string()), self.snapshot_event()]; + if should_exit { + events.push(TrayEvent::ExitRequested); + } + events + } + fn finish_config_write(&mut self, result: Result<(), RairstreamError>) -> Vec { match result { Ok(()) => vec![self.snapshot_event()], diff --git a/src/ui/tray/windows.rs b/src/ui/tray/windows.rs index 6cda3d8..3a398bc 100644 --- a/src/ui/tray/windows.rs +++ b/src/ui/tray/windows.rs @@ -1,8 +1,9 @@ use std::collections::HashMap; use std::os::windows::process::CommandExt; use std::process::Command; -use std::sync::mpsc::{self, Sender}; +use std::sync::mpsc::{self, RecvTimeoutError, Sender}; use std::thread; +use std::time::Duration; use native_dialog::{DialogBuilder, MessageLevel}; use tray_icon::menu::{ @@ -20,6 +21,7 @@ use super::{TrayCommand, TrayEvent, TrayPhase, TraySnapshot, TrayWorker, tray_re const APP_TITLE: &str = "Rairstream"; const POWERSHELL_CREATE_NO_WINDOW: u32 = 0x0800_0000; +const PLAYBACK_POLL_INTERVAL: Duration = Duration::from_millis(250); const MENU_ID_REFRESH: &str = "refresh"; const MENU_ID_START_STREAMING: &str = "start-streaming"; @@ -409,17 +411,35 @@ fn spawn_worker(proxy: EventLoopProxy) -> Result, worker.snapshot(), ))); - while let Ok(command) = command_receiver.recv() { - let events = worker.handle_command(command); - let should_exit = events - .iter() - .any(|event| matches!(event, TrayEvent::ExitRequested)); - for event in events { - let _ = proxy.send_event(UserEvent::Worker(event)); - } - - if should_exit { - break; + loop { + match command_receiver.recv_timeout(PLAYBACK_POLL_INTERVAL) { + Ok(command) => { + let events = worker.handle_command(command); + let should_exit = events + .iter() + .any(|event| matches!(event, TrayEvent::ExitRequested)); + for event in events { + let _ = proxy.send_event(UserEvent::Worker(event)); + } + + if should_exit { + break; + } + } + Err(RecvTimeoutError::Timeout) => { + let events = worker.poll(); + let should_exit = events + .iter() + .any(|event| matches!(event, TrayEvent::ExitRequested)); + for event in events { + let _ = proxy.send_event(UserEvent::Worker(event)); + } + + if should_exit { + break; + } + } + Err(RecvTimeoutError::Disconnected) => break, } } })