Skip to main content

mavlink_core/connection/udp/
async.rs

1//! Async UDP MAVLink connection
2
3use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10    io::{AsyncRead, AsyncWrite, ReadBuf},
11    net::UdpSocket,
12};
13
14use crate::MAVLinkMessageRaw;
15use crate::connection::udp::config::{UdpConfig, UdpMode};
16use crate::connection_shared::{
17    ConnectionState, next_send_header, read_message_async, read_raw_message_async,
18    write_message_async, write_raw_message_async,
19};
20use crate::{MavHeader, MavlinkVersion, Message, async_peek_reader::AsyncPeekReader};
21
22use crate::connection::{AsyncConnectable, AsyncMavConnection, get_socket_addr};
23
24#[cfg(feature = "mav2-message-signing")]
25use crate::SigningConfig;
26
27struct UdpRead {
28    socket: Arc<UdpSocket>,
29    buffer: VecDeque<u8>,
30    last_recv_address: Option<std::net::SocketAddr>,
31}
32
33const MTU_SIZE: usize = 1500;
34impl AsyncRead for UdpRead {
35    fn poll_read(
36        mut self: core::pin::Pin<&mut Self>,
37        cx: &mut core::task::Context<'_>,
38        buf: &mut ReadBuf<'_>,
39    ) -> Poll<io::Result<()>> {
40        if self.buffer.is_empty() {
41            let mut read_buffer = [0u8; MTU_SIZE];
42            let mut read_buffer = ReadBuf::new(&mut read_buffer);
43
44            match self.socket.poll_recv_from(cx, &mut read_buffer) {
45                Poll::Ready(Ok(address)) => {
46                    let n_buffer = read_buffer.filled().len();
47
48                    let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
49                    buf.advance(n);
50
51                    self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
52                    self.last_recv_address = Some(address);
53                    Poll::Ready(Ok(()))
54                }
55                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
56                Poll::Pending => Poll::Pending,
57            }
58        } else {
59            let read_result = self.buffer.read(buf.initialize_unfilled());
60            let result = match read_result {
61                Ok(n) => {
62                    buf.advance(n);
63                    Ok(())
64                }
65                Err(err) => Err(err),
66            };
67            Poll::Ready(result)
68        }
69    }
70}
71
72struct UdpWrite {
73    socket: Arc<UdpSocket>,
74    dest: Option<std::net::SocketAddr>,
75    sequence: u8,
76}
77
78impl AsyncWrite for UdpWrite {
79    fn poll_write(
80        self: core::pin::Pin<&mut Self>,
81        cx: &mut core::task::Context<'_>,
82        buf: &[u8],
83    ) -> Poll<io::Result<usize>> {
84        let this = self.get_mut();
85        let addr = this.dest.expect("`dest` is checked before write");
86
87        match this.socket.poll_send_to(cx, buf, addr) {
88            Poll::Ready(Ok(written)) if written == buf.len() => Poll::Ready(Ok(written)),
89            Poll::Ready(Ok(_)) => Poll::Ready(Err(io::Error::new(
90                io::ErrorKind::WriteZero,
91                "failed to send complete UDP datagram",
92            ))),
93            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
94            Poll::Pending => Poll::Pending,
95        }
96    }
97
98    fn poll_flush(
99        self: core::pin::Pin<&mut Self>,
100        _cx: &mut core::task::Context<'_>,
101    ) -> Poll<io::Result<()>> {
102        Poll::Ready(Ok(()))
103    }
104
105    fn poll_shutdown(
106        self: core::pin::Pin<&mut Self>,
107        _cx: &mut core::task::Context<'_>,
108    ) -> Poll<io::Result<()>> {
109        Poll::Ready(Ok(()))
110    }
111}
112
113pub struct AsyncUdpConnection {
114    reader: Mutex<AsyncPeekReader<UdpRead>>,
115    writer: Mutex<UdpWrite>,
116    state: ConnectionState,
117    server: bool,
118}
119
120impl AsyncUdpConnection {
121    fn new(
122        socket: UdpSocket,
123        server: bool,
124        dest: Option<std::net::SocketAddr>,
125    ) -> io::Result<Self> {
126        let socket = Arc::new(socket);
127        Ok(Self {
128            server,
129            reader: Mutex::new(AsyncPeekReader::new(UdpRead {
130                socket: socket.clone(),
131                buffer: VecDeque::new(),
132                last_recv_address: None,
133            })),
134            writer: Mutex::new(UdpWrite {
135                socket,
136                dest,
137                sequence: 0,
138            }),
139            state: ConnectionState::new(),
140        })
141    }
142
143    async fn update_reply_destination(&self, reader: &mut AsyncPeekReader<UdpRead>) {
144        if self.server {
145            if let addr @ Some(_) = reader.reader_ref().last_recv_address {
146                self.writer.lock().await.dest = addr;
147            }
148        }
149    }
150}
151
152#[async_trait::async_trait]
153impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
154    async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
155        let mut reader = self.reader.lock().await;
156        loop {
157            let result = read_message_async::<M, _>(reader.deref_mut(), &self.state).await;
158            self.update_reply_destination(reader.deref_mut()).await;
159            if let ok @ Ok(..) = result {
160                return ok;
161            }
162        }
163    }
164
165    async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
166        let mut reader = self.reader.lock().await;
167        loop {
168            let result = read_raw_message_async::<M, _>(reader.deref_mut(), &self.state).await;
169            self.update_reply_destination(reader.deref_mut()).await;
170            if let ok @ Ok(..) = result {
171                return ok;
172            }
173        }
174    }
175
176    async fn try_recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
177        let mut reader = self.reader.lock().await;
178        let result = read_message_async::<M, _>(reader.deref_mut(), &self.state).await;
179        self.update_reply_destination(reader.deref_mut()).await;
180
181        result
182    }
183
184    async fn send(
185        &self,
186        header: &MavHeader,
187        data: &M,
188    ) -> Result<usize, crate::error::MessageWriteError> {
189        let mut guard = self.writer.lock().await;
190        let writer = &mut *guard;
191
192        let header = next_send_header(&mut writer.sequence, header);
193
194        let len = if writer.dest.is_some() {
195            write_message_async(writer, &self.state, header, data).await?
196        } else {
197            0
198        };
199
200        Ok(len)
201    }
202
203    async fn send_raw(
204        &self,
205        data: &MAVLinkMessageRaw,
206    ) -> Result<usize, crate::error::MessageWriteError> {
207        let mut guard = self.writer.lock().await;
208        let writer = &mut *guard;
209
210        let len = if writer.dest.is_some() {
211            write_raw_message_async(writer, data).await?
212        } else {
213            0
214        };
215
216        Ok(len)
217    }
218
219    fn set_protocol_version(&mut self, version: MavlinkVersion) {
220        self.state.set_protocol_version(version);
221    }
222
223    fn protocol_version(&self) -> MavlinkVersion {
224        self.state.protocol_version()
225    }
226
227    fn set_allow_recv_any_version(&mut self, allow: bool) {
228        self.state.set_allow_recv_any_version(allow);
229    }
230
231    fn allow_recv_any_version(&self) -> bool {
232        self.state.allow_recv_any_version()
233    }
234
235    #[cfg(feature = "mav2-message-signing")]
236    fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
237        self.state.setup_signing(signing_data);
238    }
239}
240
241#[async_trait]
242impl AsyncConnectable for UdpConfig {
243    async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
244    where
245        M: Message + Sync + Send,
246    {
247        let (addr, server, dest): (&str, _, _) = match self.mode {
248            UdpMode::Udpin => (&self.address, true, None),
249            _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
250        };
251        let socket = UdpSocket::bind(addr).await?;
252        if matches!(self.mode, UdpMode::UdpBroadcast) {
253            socket.set_broadcast(true)?;
254        }
255        Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// Two 50-byte datagrams read through a 30-byte buffer must come back as
    /// 30 + 20 bytes each. The unread tail of a datagram is buffered and
    /// returned before the next datagram is received.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Use OS-assigned ports so parallel test runs, or other processes,
        // cannot collide on a fixed port.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket,
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect();
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];
        for _ in 0..2 {
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }

        assert_eq!(
            udp_reader.last_recv_address,
            Some(sender_socket.local_addr().unwrap())
        );
    }
}