Skip to main content

mavlink_core/connection/udp/
async.rs

1//! Async UDP MAVLink connection
2
3use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10    io::{AsyncRead, ReadBuf},
11    net::UdpSocket,
12};
13
14use crate::MAVLinkMessageRaw;
15use crate::connection::udp::config::{UdpConfig, UdpMode};
16use crate::connection_shared::{
17    ConnectionState, next_send_header, read_message_async, read_raw_message_async,
18    write_message_async,
19};
20use crate::{MavHeader, MavlinkVersion, Message, async_peek_reader::AsyncPeekReader};
21
22use crate::connection::{AsyncConnectable, AsyncMavConnection, get_socket_addr};
23
24#[cfg(feature = "mav2-message-signing")]
25use crate::SigningConfig;
26
27struct UdpRead {
28    socket: Arc<UdpSocket>,
29    buffer: VecDeque<u8>,
30    last_recv_address: Option<std::net::SocketAddr>,
31}
32
33const MTU_SIZE: usize = 1500;
34impl AsyncRead for UdpRead {
35    fn poll_read(
36        mut self: core::pin::Pin<&mut Self>,
37        cx: &mut core::task::Context<'_>,
38        buf: &mut ReadBuf<'_>,
39    ) -> Poll<io::Result<()>> {
40        if self.buffer.is_empty() {
41            let mut read_buffer = [0u8; MTU_SIZE];
42            let mut read_buffer = ReadBuf::new(&mut read_buffer);
43
44            match self.socket.poll_recv_from(cx, &mut read_buffer) {
45                Poll::Ready(Ok(address)) => {
46                    let n_buffer = read_buffer.filled().len();
47
48                    let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
49                    buf.advance(n);
50
51                    self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
52                    self.last_recv_address = Some(address);
53                    Poll::Ready(Ok(()))
54                }
55                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
56                Poll::Pending => Poll::Pending,
57            }
58        } else {
59            let read_result = self.buffer.read(buf.initialize_unfilled());
60            let result = match read_result {
61                Ok(n) => {
62                    buf.advance(n);
63                    Ok(())
64                }
65                Err(err) => Err(err),
66            };
67            Poll::Ready(result)
68        }
69    }
70}
71
72struct UdpWrite {
73    socket: Arc<UdpSocket>,
74    dest: Option<std::net::SocketAddr>,
75    sequence: u8,
76}
77
78pub struct AsyncUdpConnection {
79    reader: Mutex<AsyncPeekReader<UdpRead>>,
80    writer: Mutex<UdpWrite>,
81    state: ConnectionState,
82    server: bool,
83}
84
85impl AsyncUdpConnection {
86    fn new(
87        socket: UdpSocket,
88        server: bool,
89        dest: Option<std::net::SocketAddr>,
90    ) -> io::Result<Self> {
91        let socket = Arc::new(socket);
92        Ok(Self {
93            server,
94            reader: Mutex::new(AsyncPeekReader::new(UdpRead {
95                socket: socket.clone(),
96                buffer: VecDeque::new(),
97                last_recv_address: None,
98            })),
99            writer: Mutex::new(UdpWrite {
100                socket,
101                dest,
102                sequence: 0,
103            }),
104            state: ConnectionState::new(),
105        })
106    }
107
108    async fn update_reply_destination(&self, reader: &mut AsyncPeekReader<UdpRead>) {
109        if self.server {
110            if let addr @ Some(_) = reader.reader_ref().last_recv_address {
111                self.writer.lock().await.dest = addr;
112            }
113        }
114    }
115}
116
117#[async_trait::async_trait]
118impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
119    async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
120        let mut reader = self.reader.lock().await;
121        loop {
122            let result = read_message_async::<M, _>(reader.deref_mut(), &self.state).await;
123            self.update_reply_destination(reader.deref_mut()).await;
124            if let ok @ Ok(..) = result {
125                return ok;
126            }
127        }
128    }
129
130    async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
131        let mut reader = self.reader.lock().await;
132        loop {
133            let result = read_raw_message_async::<M, _>(reader.deref_mut(), &self.state).await;
134            self.update_reply_destination(reader.deref_mut()).await;
135            if let ok @ Ok(..) = result {
136                return ok;
137            }
138        }
139    }
140
141    async fn try_recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
142        let mut reader = self.reader.lock().await;
143        let result = read_message_async::<M, _>(reader.deref_mut(), &self.state).await;
144        self.update_reply_destination(reader.deref_mut()).await;
145
146        result
147    }
148
149    async fn send(
150        &self,
151        header: &MavHeader,
152        data: &M,
153    ) -> Result<usize, crate::error::MessageWriteError> {
154        let mut guard = self.writer.lock().await;
155        let state = &mut *guard;
156
157        let header = next_send_header(&mut state.sequence, header);
158
159        let len = if let Some(addr) = state.dest {
160            let mut buf = Vec::new();
161            write_message_async(&mut buf, &self.state, header, data).await?;
162            state.socket.send_to(&buf, addr).await?
163        } else {
164            0
165        };
166
167        Ok(len)
168    }
169
170    fn set_protocol_version(&mut self, version: MavlinkVersion) {
171        self.state.set_protocol_version(version);
172    }
173
174    fn protocol_version(&self) -> MavlinkVersion {
175        self.state.protocol_version()
176    }
177
178    fn set_allow_recv_any_version(&mut self, allow: bool) {
179        self.state.set_allow_recv_any_version(allow);
180    }
181
182    fn allow_recv_any_version(&self) -> bool {
183        self.state.allow_recv_any_version()
184    }
185
186    #[cfg(feature = "mav2-message-signing")]
187    fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
188        self.state.setup_signing(signing_data);
189    }
190}
191
192#[async_trait]
193impl AsyncConnectable for UdpConfig {
194    async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
195    where
196        M: Message + Sync + Send,
197    {
198        let (addr, server, dest): (&str, _, _) = match self.mode {
199            UdpMode::Udpin => (&self.address, true, None),
200            _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
201        };
202        let socket = UdpSocket::bind(addr).await?;
203        if matches!(self.mode, UdpMode::UdpBroadcast) {
204            socket.set_broadcast(true)?;
205        }
206        Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// Reads with a buffer smaller than a datagram must hand the datagram out in
    /// pieces and must never splice two datagrams together in one read.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Ephemeral port: a hard-coded one can collide with other processes or
        // with tests running in parallel.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect();
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        // Each 50-byte datagram is delivered as 30 + 20 bytes.
        let mut buf = [0u8; 30];
        for expected in [0u8..30, 30..50, 0..30, 30..50] {
            let expected: Vec<u8> = expected.collect();
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, expected.len());
            assert_eq!(&buf[..n_read], expected.as_slice());
        }
    }
}