Skip to main content

mavlink_core/async_connection/
udp.rs

1//! Async UDP MAVLink connection
2
3use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10    io::{AsyncRead, ReadBuf},
11    net::UdpSocket,
12};
13
14use crate::connection::udp::config::{UdpConfig, UdpMode};
15use crate::MAVLinkMessageRaw;
16use crate::{async_peek_reader::AsyncPeekReader, MavHeader, MavlinkVersion, Message, ReadVersion};
17
18use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
19
20#[cfg(not(feature = "mav2-message-signing"))]
21use crate::{
22    read_versioned_msg_async, read_versioned_raw_message_async, write_versioned_msg_async,
23};
24#[cfg(feature = "mav2-message-signing")]
25use crate::{
26    read_versioned_msg_async_signed, read_versioned_raw_message_async_signed,
27    write_versioned_msg_signed, SigningConfig, SigningData,
28};
29
30struct UdpRead {
31    socket: Arc<UdpSocket>,
32    buffer: VecDeque<u8>,
33    last_recv_address: Option<std::net::SocketAddr>,
34}
35
36const MTU_SIZE: usize = 1500;
37impl AsyncRead for UdpRead {
38    fn poll_read(
39        mut self: core::pin::Pin<&mut Self>,
40        cx: &mut core::task::Context<'_>,
41        buf: &mut ReadBuf<'_>,
42    ) -> Poll<io::Result<()>> {
43        if self.buffer.is_empty() {
44            let mut read_buffer = [0u8; MTU_SIZE];
45            let mut read_buffer = ReadBuf::new(&mut read_buffer);
46
47            match self.socket.poll_recv_from(cx, &mut read_buffer) {
48                Poll::Ready(Ok(address)) => {
49                    let n_buffer = read_buffer.filled().len();
50
51                    let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
52                    buf.advance(n);
53
54                    self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
55                    self.last_recv_address = Some(address);
56                    Poll::Ready(Ok(()))
57                }
58                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
59                Poll::Pending => Poll::Pending,
60            }
61        } else {
62            let read_result = self.buffer.read(buf.initialize_unfilled());
63            let result = match read_result {
64                Ok(n) => {
65                    buf.advance(n);
66                    Ok(())
67                }
68                Err(err) => Err(err),
69            };
70            Poll::Ready(result)
71        }
72    }
73}
74
75struct UdpWrite {
76    socket: Arc<UdpSocket>,
77    dest: Option<std::net::SocketAddr>,
78    sequence: u8,
79}
80
81pub struct AsyncUdpConnection {
82    reader: Mutex<AsyncPeekReader<UdpRead>>,
83    writer: Mutex<UdpWrite>,
84    protocol_version: MavlinkVersion,
85    recv_any_version: bool,
86    server: bool,
87    #[cfg(feature = "mav2-message-signing")]
88    signing_data: Option<SigningData>,
89}
90
91impl AsyncUdpConnection {
92    fn new(
93        socket: UdpSocket,
94        server: bool,
95        dest: Option<std::net::SocketAddr>,
96    ) -> io::Result<Self> {
97        let socket = Arc::new(socket);
98        Ok(Self {
99            server,
100            reader: Mutex::new(AsyncPeekReader::new(UdpRead {
101                socket: socket.clone(),
102                buffer: VecDeque::new(),
103                last_recv_address: None,
104            })),
105            writer: Mutex::new(UdpWrite {
106                socket,
107                dest,
108                sequence: 0,
109            }),
110            protocol_version: MavlinkVersion::V2,
111            recv_any_version: false,
112            #[cfg(feature = "mav2-message-signing")]
113            signing_data: None,
114        })
115    }
116}
117
118#[async_trait::async_trait]
119impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
120    async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
121        let mut reader = self.reader.lock().await;
122        let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
123        loop {
124            #[cfg(not(feature = "mav2-message-signing"))]
125            let result = read_versioned_msg_async(reader.deref_mut(), version).await;
126            #[cfg(feature = "mav2-message-signing")]
127            let result = read_versioned_msg_async_signed(
128                reader.deref_mut(),
129                version,
130                self.signing_data.as_ref(),
131            )
132            .await;
133            if self.server {
134                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
135                    self.writer.lock().await.dest = addr;
136                }
137            }
138            if let ok @ Ok(..) = result {
139                return ok;
140            }
141        }
142    }
143
144    async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
145        let mut reader = self.reader.lock().await;
146        let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
147        loop {
148            #[cfg(not(feature = "mav2-message-signing"))]
149            let result =
150                read_versioned_raw_message_async::<M, _>(reader.deref_mut(), version).await;
151            #[cfg(feature = "mav2-message-signing")]
152            let result = read_versioned_raw_message_async_signed::<M, _>(
153                reader.deref_mut(),
154                version,
155                self.signing_data.as_ref(),
156            )
157            .await;
158            if self.server {
159                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
160                    self.writer.lock().await.dest = addr;
161                }
162            }
163            if let ok @ Ok(..) = result {
164                return ok;
165            }
166        }
167    }
168
169    async fn try_recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
170        let mut reader = self.reader.lock().await;
171        let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
172
173        #[cfg(not(feature = "mav2-message-signing"))]
174        let result = read_versioned_msg_async(reader.deref_mut(), version).await;
175
176        #[cfg(feature = "mav2-message-signing")]
177        let result = read_versioned_msg_async_signed(
178            reader.deref_mut(),
179            version,
180            self.signing_data.as_ref(),
181        )
182        .await;
183
184        if self.server {
185            if let addr @ Some(_) = reader.reader_ref().last_recv_address {
186                self.writer.lock().await.dest = addr;
187            }
188        }
189
190        result
191    }
192
193    async fn send(
194        &self,
195        header: &MavHeader,
196        data: &M,
197    ) -> Result<usize, crate::error::MessageWriteError> {
198        let mut guard = self.writer.lock().await;
199        let state = &mut *guard;
200
201        let header = MavHeader {
202            sequence: state.sequence,
203            system_id: header.system_id,
204            component_id: header.component_id,
205        };
206
207        state.sequence = state.sequence.wrapping_add(1);
208
209        let len = if let Some(addr) = state.dest {
210            let mut buf = Vec::new();
211            #[cfg(not(feature = "mav2-message-signing"))]
212            write_versioned_msg_async(
213                &mut buf,
214                self.protocol_version,
215                header,
216                data,
217                #[cfg(feature = "mav2-message-signing")]
218                self.signing_data.as_ref(),
219            )
220            .await?;
221            #[cfg(feature = "mav2-message-signing")]
222            write_versioned_msg_signed(
223                &mut buf,
224                self.protocol_version,
225                header,
226                data,
227                #[cfg(feature = "mav2-message-signing")]
228                self.signing_data.as_ref(),
229            )?;
230            state.socket.send_to(&buf, addr).await?
231        } else {
232            0
233        };
234
235        Ok(len)
236    }
237
238    fn set_protocol_version(&mut self, version: MavlinkVersion) {
239        self.protocol_version = version;
240    }
241
242    fn protocol_version(&self) -> MavlinkVersion {
243        self.protocol_version
244    }
245
246    fn set_allow_recv_any_version(&mut self, allow: bool) {
247        self.recv_any_version = allow;
248    }
249
250    fn allow_recv_any_version(&self) -> bool {
251        self.recv_any_version
252    }
253
254    #[cfg(feature = "mav2-message-signing")]
255    fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
256        self.signing_data = signing_data.map(SigningData::from_config);
257    }
258}
259
260#[async_trait]
261impl AsyncConnectable for UdpConfig {
262    async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
263    where
264        M: Message + Sync + Send,
265    {
266        let (addr, server, dest): (&str, _, _) = match self.mode {
267            UdpMode::Udpin => (&self.address, true, None),
268            _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
269        };
270        let socket = UdpSocket::bind(addr).await?;
271        if matches!(self.mode, UdpMode::UdpBroadcast) {
272            socket.set_broadcast(true)?;
273        }
274        Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// A datagram longer than the caller's buffer is handed out across several reads, and
    /// consecutive datagrams are delivered whole and in order.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Let the OS pick the port so parallel tests or other processes cannot collide with a
        // hard-coded one.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket,
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect::<Vec<_>>();

        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];

        for _ in 0..2 {
            // First read takes as much of the datagram as fits...
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[0..n_read], (0..30).collect::<Vec<_>>().as_slice());

            // ...and the next one returns the buffered remainder of the same datagram.
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[0..n_read], (30..50).collect::<Vec<_>>().as_slice());
        }

        // The reader records where the datagrams came from.
        assert_eq!(
            udp_reader.last_recv_address.map(|addr| addr.port()),
            Some(sender_socket.local_addr().unwrap().port())
        );
    }
}