mavlink_core/async_connection/
udp.rs1use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10 io::{AsyncRead, ReadBuf},
11 net::UdpSocket,
12};
13
14use crate::connection::udp::config::{UdpConfig, UdpMode};
15use crate::MAVLinkMessageRaw;
16use crate::{async_peek_reader::AsyncPeekReader, MavHeader, MavlinkVersion, Message, ReadVersion};
17
18use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
19
20#[cfg(not(feature = "mav2-message-signing"))]
21use crate::{
22 read_versioned_msg_async, read_versioned_raw_message_async, write_versioned_msg_async,
23};
24#[cfg(feature = "mav2-message-signing")]
25use crate::{
26 read_versioned_msg_async_signed, read_versioned_raw_message_async_signed,
27 write_versioned_msg_signed, SigningConfig, SigningData,
28};
29
30struct UdpRead {
31 socket: Arc<UdpSocket>,
32 buffer: VecDeque<u8>,
33 last_recv_address: Option<std::net::SocketAddr>,
34}
35
36const MTU_SIZE: usize = 1500;
37impl AsyncRead for UdpRead {
38 fn poll_read(
39 mut self: core::pin::Pin<&mut Self>,
40 cx: &mut core::task::Context<'_>,
41 buf: &mut ReadBuf<'_>,
42 ) -> Poll<io::Result<()>> {
43 if self.buffer.is_empty() {
44 let mut read_buffer = [0u8; MTU_SIZE];
45 let mut read_buffer = ReadBuf::new(&mut read_buffer);
46
47 match self.socket.poll_recv_from(cx, &mut read_buffer) {
48 Poll::Ready(Ok(address)) => {
49 let n_buffer = read_buffer.filled().len();
50
51 let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
52 buf.advance(n);
53
54 self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
55 self.last_recv_address = Some(address);
56 Poll::Ready(Ok(()))
57 }
58 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
59 Poll::Pending => Poll::Pending,
60 }
61 } else {
62 let read_result = self.buffer.read(buf.initialize_unfilled());
63 let result = match read_result {
64 Ok(n) => {
65 buf.advance(n);
66 Ok(())
67 }
68 Err(err) => Err(err),
69 };
70 Poll::Ready(result)
71 }
72 }
73}
74
75struct UdpWrite {
76 socket: Arc<UdpSocket>,
77 dest: Option<std::net::SocketAddr>,
78 sequence: u8,
79}
80
81pub struct AsyncUdpConnection {
82 reader: Mutex<AsyncPeekReader<UdpRead>>,
83 writer: Mutex<UdpWrite>,
84 protocol_version: MavlinkVersion,
85 recv_any_version: bool,
86 server: bool,
87 #[cfg(feature = "mav2-message-signing")]
88 signing_data: Option<SigningData>,
89}
90
91impl AsyncUdpConnection {
92 fn new(
93 socket: UdpSocket,
94 server: bool,
95 dest: Option<std::net::SocketAddr>,
96 ) -> io::Result<Self> {
97 let socket = Arc::new(socket);
98 Ok(Self {
99 server,
100 reader: Mutex::new(AsyncPeekReader::new(UdpRead {
101 socket: socket.clone(),
102 buffer: VecDeque::new(),
103 last_recv_address: None,
104 })),
105 writer: Mutex::new(UdpWrite {
106 socket,
107 dest,
108 sequence: 0,
109 }),
110 protocol_version: MavlinkVersion::V2,
111 recv_any_version: false,
112 #[cfg(feature = "mav2-message-signing")]
113 signing_data: None,
114 })
115 }
116}
117
118#[async_trait::async_trait]
119impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
120 async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
121 let mut reader = self.reader.lock().await;
122 let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
123 loop {
124 #[cfg(not(feature = "mav2-message-signing"))]
125 let result = read_versioned_msg_async(reader.deref_mut(), version).await;
126 #[cfg(feature = "mav2-message-signing")]
127 let result = read_versioned_msg_async_signed(
128 reader.deref_mut(),
129 version,
130 self.signing_data.as_ref(),
131 )
132 .await;
133 if self.server {
134 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
135 self.writer.lock().await.dest = addr;
136 }
137 }
138 if let ok @ Ok(..) = result {
139 return ok;
140 }
141 }
142 }
143
144 async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
145 let mut reader = self.reader.lock().await;
146 let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
147 loop {
148 #[cfg(not(feature = "mav2-message-signing"))]
149 let result =
150 read_versioned_raw_message_async::<M, _>(reader.deref_mut(), version).await;
151 #[cfg(feature = "mav2-message-signing")]
152 let result = read_versioned_raw_message_async_signed::<M, _>(
153 reader.deref_mut(),
154 version,
155 self.signing_data.as_ref(),
156 )
157 .await;
158 if self.server {
159 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
160 self.writer.lock().await.dest = addr;
161 }
162 }
163 if let ok @ Ok(..) = result {
164 return ok;
165 }
166 }
167 }
168
169 async fn try_recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
170 let mut reader = self.reader.lock().await;
171 let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
172
173 #[cfg(not(feature = "mav2-message-signing"))]
174 let result = read_versioned_msg_async(reader.deref_mut(), version).await;
175
176 #[cfg(feature = "mav2-message-signing")]
177 let result = read_versioned_msg_async_signed(
178 reader.deref_mut(),
179 version,
180 self.signing_data.as_ref(),
181 )
182 .await;
183
184 if self.server {
185 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
186 self.writer.lock().await.dest = addr;
187 }
188 }
189
190 result
191 }
192
193 async fn send(
194 &self,
195 header: &MavHeader,
196 data: &M,
197 ) -> Result<usize, crate::error::MessageWriteError> {
198 let mut guard = self.writer.lock().await;
199 let state = &mut *guard;
200
201 let header = MavHeader {
202 sequence: state.sequence,
203 system_id: header.system_id,
204 component_id: header.component_id,
205 };
206
207 state.sequence = state.sequence.wrapping_add(1);
208
209 let len = if let Some(addr) = state.dest {
210 let mut buf = Vec::new();
211 #[cfg(not(feature = "mav2-message-signing"))]
212 write_versioned_msg_async(
213 &mut buf,
214 self.protocol_version,
215 header,
216 data,
217 #[cfg(feature = "mav2-message-signing")]
218 self.signing_data.as_ref(),
219 )
220 .await?;
221 #[cfg(feature = "mav2-message-signing")]
222 write_versioned_msg_signed(
223 &mut buf,
224 self.protocol_version,
225 header,
226 data,
227 #[cfg(feature = "mav2-message-signing")]
228 self.signing_data.as_ref(),
229 )?;
230 state.socket.send_to(&buf, addr).await?
231 } else {
232 0
233 };
234
235 Ok(len)
236 }
237
238 fn set_protocol_version(&mut self, version: MavlinkVersion) {
239 self.protocol_version = version;
240 }
241
242 fn protocol_version(&self) -> MavlinkVersion {
243 self.protocol_version
244 }
245
246 fn set_allow_recv_any_version(&mut self, allow: bool) {
247 self.recv_any_version = allow;
248 }
249
250 fn allow_recv_any_version(&self) -> bool {
251 self.recv_any_version
252 }
253
254 #[cfg(feature = "mav2-message-signing")]
255 fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
256 self.signing_data = signing_data.map(SigningData::from_config);
257 }
258}
259
260#[async_trait]
261impl AsyncConnectable for UdpConfig {
262 async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
263 where
264 M: Message + Sync + Send,
265 {
266 let (addr, server, dest): (&str, _, _) = match self.mode {
267 UdpMode::Udpin => (&self.address, true, None),
268 _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
269 };
270 let socket = UdpSocket::bind(addr).await?;
271 if matches!(self.mode, UdpMode::UdpBroadcast) {
272 socket.set_broadcast(true)?;
273 }
274 Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// A datagram larger than the caller's buffer must be delivered across
    /// several reads, and the following datagram must start on a fresh read.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Bind to port 0 so the OS picks a free port; a fixed port makes the
        // test fail whenever that port is already taken.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();

        let mut udp_reader = UdpRead {
            socket: Arc::clone(&receiver_socket),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };

        let sender_socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect();

        // Queue two identical 50-byte datagrams.
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        // Each datagram is consumed as a 30-byte read followed by the
        // remaining 20 bytes.
        let mut buf = [0u8; 30];
        for _ in 0..2 {
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }
    }
}