mavlink_core/async_connection/
udp.rs1use core::{ops::DerefMut, task::Poll};
4use std::{collections::VecDeque, io::Read, sync::Arc};
5
6use async_trait::async_trait;
7use tokio::{
8 io::{self, AsyncRead, ReadBuf},
9 net::UdpSocket,
10 sync::Mutex,
11};
12
13use crate::{
14 async_peek_reader::AsyncPeekReader,
15 connectable::{UdpConnectable, UdpMode},
16 MavHeader, MavlinkVersion, Message, ReadVersion,
17};
18
19use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
20
21#[cfg(not(feature = "signing"))]
22use crate::{read_versioned_msg_async, write_versioned_msg_async};
23#[cfg(feature = "signing")]
24use crate::{
25 read_versioned_msg_async_signed, write_versioned_msg_signed, SigningConfig, SigningData,
26};
27
28struct UdpRead {
29 socket: Arc<UdpSocket>,
30 buffer: VecDeque<u8>,
31 last_recv_address: Option<std::net::SocketAddr>,
32}
33
/// Capacity, in bytes, of the stack buffer each incoming datagram is received into.
///
/// 1500 matches a typical Ethernet MTU. A longer datagram does not fit, and the
/// platform's receive call may discard the excess bytes.
const MTU_SIZE: usize = 1_500;
35impl AsyncRead for UdpRead {
36 fn poll_read(
37 mut self: core::pin::Pin<&mut Self>,
38 cx: &mut core::task::Context<'_>,
39 buf: &mut io::ReadBuf<'_>,
40 ) -> Poll<io::Result<()>> {
41 if self.buffer.is_empty() {
42 let mut read_buffer = [0u8; MTU_SIZE];
43 let mut read_buffer = ReadBuf::new(&mut read_buffer);
44
45 match self.socket.poll_recv_from(cx, &mut read_buffer) {
46 Poll::Ready(Ok(address)) => {
47 let n_buffer = read_buffer.filled().len();
48
49 let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
50 buf.advance(n);
51
52 self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
53 self.last_recv_address = Some(address);
54 Poll::Ready(Ok(()))
55 }
56 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
57 Poll::Pending => Poll::Pending,
58 }
59 } else {
60 let read_result = self.buffer.read(buf.initialize_unfilled());
61 let result = match read_result {
62 Ok(n) => {
63 buf.advance(n);
64 Ok(())
65 }
66 Err(err) => Err(err),
67 };
68 Poll::Ready(result)
69 }
70 }
71}
72
73struct UdpWrite {
74 socket: Arc<UdpSocket>,
75 dest: Option<std::net::SocketAddr>,
76 sequence: u8,
77}
78
79pub struct AsyncUdpConnection {
80 reader: Mutex<AsyncPeekReader<UdpRead>>,
81 writer: Mutex<UdpWrite>,
82 protocol_version: MavlinkVersion,
83 recv_any_version: bool,
84 server: bool,
85 #[cfg(feature = "signing")]
86 signing_data: Option<SigningData>,
87}
88
89impl AsyncUdpConnection {
90 fn new(
91 socket: UdpSocket,
92 server: bool,
93 dest: Option<std::net::SocketAddr>,
94 ) -> io::Result<Self> {
95 let socket = Arc::new(socket);
96 Ok(Self {
97 server,
98 reader: Mutex::new(AsyncPeekReader::new(UdpRead {
99 socket: socket.clone(),
100 buffer: VecDeque::new(),
101 last_recv_address: None,
102 })),
103 writer: Mutex::new(UdpWrite {
104 socket,
105 dest,
106 sequence: 0,
107 }),
108 protocol_version: MavlinkVersion::V2,
109 recv_any_version: false,
110 #[cfg(feature = "signing")]
111 signing_data: None,
112 })
113 }
114}
115
116#[async_trait::async_trait]
117impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
118 async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
119 let mut reader = self.reader.lock().await;
120 let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
121 loop {
122 #[cfg(not(feature = "signing"))]
123 let result = read_versioned_msg_async(reader.deref_mut(), version).await;
124 #[cfg(feature = "signing")]
125 let result = read_versioned_msg_async_signed(
126 reader.deref_mut(),
127 version,
128 self.signing_data.as_ref(),
129 )
130 .await;
131 if self.server {
132 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
133 self.writer.lock().await.dest = addr;
134 }
135 }
136 if let ok @ Ok(..) = result {
137 return ok;
138 }
139 }
140 }
141
142 async fn send(
143 &self,
144 header: &MavHeader,
145 data: &M,
146 ) -> Result<usize, crate::error::MessageWriteError> {
147 let mut guard = self.writer.lock().await;
148 let state = &mut *guard;
149
150 let header = MavHeader {
151 sequence: state.sequence,
152 system_id: header.system_id,
153 component_id: header.component_id,
154 };
155
156 state.sequence = state.sequence.wrapping_add(1);
157
158 let len = if let Some(addr) = state.dest {
159 let mut buf = Vec::new();
160 #[cfg(not(feature = "signing"))]
161 write_versioned_msg_async(
162 &mut buf,
163 self.protocol_version,
164 header,
165 data,
166 #[cfg(feature = "signing")]
167 self.signing_data.as_ref(),
168 )
169 .await?;
170 #[cfg(feature = "signing")]
171 write_versioned_msg_signed(
172 &mut buf,
173 self.protocol_version,
174 header,
175 data,
176 #[cfg(feature = "signing")]
177 self.signing_data.as_ref(),
178 )?;
179 state.socket.send_to(&buf, addr).await?
180 } else {
181 0
182 };
183
184 Ok(len)
185 }
186
187 fn set_protocol_version(&mut self, version: MavlinkVersion) {
188 self.protocol_version = version;
189 }
190
191 fn protocol_version(&self) -> MavlinkVersion {
192 self.protocol_version
193 }
194
195 fn set_allow_recv_any_version(&mut self, allow: bool) {
196 self.recv_any_version = allow
197 }
198
199 fn allow_recv_any_version(&self) -> bool {
200 self.recv_any_version
201 }
202
203 #[cfg(feature = "signing")]
204 fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
205 self.signing_data = signing_data.map(SigningData::from_config)
206 }
207}
208
209#[async_trait]
210impl AsyncConnectable for UdpConnectable {
211 async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
212 where
213 M: Message + Sync + Send,
214 {
215 let (addr, server, dest): (&str, _, _) = match self.mode {
216 UdpMode::Udpin => (&self.address, true, None),
217 _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
218 };
219 let socket = UdpSocket::bind(addr).await?;
220 if matches!(self.mode, UdpMode::Udpcast) {
221 socket.set_broadcast(true)?;
222 }
223 Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use io::AsyncReadExt;

    /// Two 50-byte datagrams read through a 30-byte buffer must come back as
    /// 30 + 20 bytes per datagram, in order. No bytes may be lost or merged
    /// across the datagram boundary, and the sender's address must be recorded.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Let the OS pick free ports so the test cannot clash with another
        // process, or with a concurrent test run, that uses a fixed port.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect();
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let head: Vec<u8> = (0..30).collect();
        let tail: Vec<u8> = (30..50).collect();
        let mut buf = [0u8; 30];

        for _ in 0..2 {
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], head.as_slice());
            assert_eq!(
                udp_reader.last_recv_address,
                Some(sender_socket.local_addr().unwrap())
            );

            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], tail.as_slice());
        }
    }
}