mavlink_core/async_connection/
udp.rs

1//! Async UDP MAVLink connection
2
3use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10    io::{AsyncRead, ReadBuf},
11    net::UdpSocket,
12};
13
14use crate::connection::udp::config::{UdpConfig, UdpMode};
15use crate::MAVLinkMessageRaw;
16use crate::{async_peek_reader::AsyncPeekReader, MavHeader, MavlinkVersion, Message, ReadVersion};
17
18use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
19
20#[cfg(not(feature = "signing"))]
21use crate::{read_raw_versioned_msg_async, read_versioned_msg_async, write_versioned_msg_async};
22#[cfg(feature = "signing")]
23use crate::{
24    read_raw_versioned_msg_async_signed, read_versioned_msg_async_signed,
25    write_versioned_msg_signed, SigningConfig, SigningData,
26};
27
28struct UdpRead {
29    socket: Arc<UdpSocket>,
30    buffer: VecDeque<u8>,
31    last_recv_address: Option<std::net::SocketAddr>,
32}
33
34const MTU_SIZE: usize = 1500;
35impl AsyncRead for UdpRead {
36    fn poll_read(
37        mut self: core::pin::Pin<&mut Self>,
38        cx: &mut core::task::Context<'_>,
39        buf: &mut ReadBuf<'_>,
40    ) -> Poll<io::Result<()>> {
41        if self.buffer.is_empty() {
42            let mut read_buffer = [0u8; MTU_SIZE];
43            let mut read_buffer = ReadBuf::new(&mut read_buffer);
44
45            match self.socket.poll_recv_from(cx, &mut read_buffer) {
46                Poll::Ready(Ok(address)) => {
47                    let n_buffer = read_buffer.filled().len();
48
49                    let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
50                    buf.advance(n);
51
52                    self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
53                    self.last_recv_address = Some(address);
54                    Poll::Ready(Ok(()))
55                }
56                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
57                Poll::Pending => Poll::Pending,
58            }
59        } else {
60            let read_result = self.buffer.read(buf.initialize_unfilled());
61            let result = match read_result {
62                Ok(n) => {
63                    buf.advance(n);
64                    Ok(())
65                }
66                Err(err) => Err(err),
67            };
68            Poll::Ready(result)
69        }
70    }
71}
72
73struct UdpWrite {
74    socket: Arc<UdpSocket>,
75    dest: Option<std::net::SocketAddr>,
76    sequence: u8,
77}
78
79pub struct AsyncUdpConnection {
80    reader: Mutex<AsyncPeekReader<UdpRead>>,
81    writer: Mutex<UdpWrite>,
82    protocol_version: MavlinkVersion,
83    recv_any_version: bool,
84    server: bool,
85    #[cfg(feature = "signing")]
86    signing_data: Option<SigningData>,
87}
88
89impl AsyncUdpConnection {
90    fn new(
91        socket: UdpSocket,
92        server: bool,
93        dest: Option<std::net::SocketAddr>,
94    ) -> io::Result<Self> {
95        let socket = Arc::new(socket);
96        Ok(Self {
97            server,
98            reader: Mutex::new(AsyncPeekReader::new(UdpRead {
99                socket: socket.clone(),
100                buffer: VecDeque::new(),
101                last_recv_address: None,
102            })),
103            writer: Mutex::new(UdpWrite {
104                socket,
105                dest,
106                sequence: 0,
107            }),
108            protocol_version: MavlinkVersion::V2,
109            recv_any_version: false,
110            #[cfg(feature = "signing")]
111            signing_data: None,
112        })
113    }
114}
115
116#[async_trait::async_trait]
117impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
118    async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
119        let mut reader = self.reader.lock().await;
120        let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
121        loop {
122            #[cfg(not(feature = "signing"))]
123            let result = read_versioned_msg_async(reader.deref_mut(), version).await;
124            #[cfg(feature = "signing")]
125            let result = read_versioned_msg_async_signed(
126                reader.deref_mut(),
127                version,
128                self.signing_data.as_ref(),
129            )
130            .await;
131            if self.server {
132                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
133                    self.writer.lock().await.dest = addr;
134                }
135            }
136            if let ok @ Ok(..) = result {
137                return ok;
138            }
139        }
140    }
141
142    async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
143        let mut reader = self.reader.lock().await;
144        let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
145        loop {
146            #[cfg(not(feature = "signing"))]
147            let result = read_raw_versioned_msg_async::<M, _>(reader.deref_mut(), version).await;
148            #[cfg(feature = "signing")]
149            let result = read_raw_versioned_msg_async_signed::<M, _>(
150                reader.deref_mut(),
151                version,
152                self.signing_data.as_ref(),
153            )
154            .await;
155            if self.server {
156                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
157                    self.writer.lock().await.dest = addr;
158                }
159            }
160            if let ok @ Ok(..) = result {
161                return ok;
162            }
163        }
164    }
165
166    async fn send(
167        &self,
168        header: &MavHeader,
169        data: &M,
170    ) -> Result<usize, crate::error::MessageWriteError> {
171        let mut guard = self.writer.lock().await;
172        let state = &mut *guard;
173
174        let header = MavHeader {
175            sequence: state.sequence,
176            system_id: header.system_id,
177            component_id: header.component_id,
178        };
179
180        state.sequence = state.sequence.wrapping_add(1);
181
182        let len = if let Some(addr) = state.dest {
183            let mut buf = Vec::new();
184            #[cfg(not(feature = "signing"))]
185            write_versioned_msg_async(
186                &mut buf,
187                self.protocol_version,
188                header,
189                data,
190                #[cfg(feature = "signing")]
191                self.signing_data.as_ref(),
192            )
193            .await?;
194            #[cfg(feature = "signing")]
195            write_versioned_msg_signed(
196                &mut buf,
197                self.protocol_version,
198                header,
199                data,
200                #[cfg(feature = "signing")]
201                self.signing_data.as_ref(),
202            )?;
203            state.socket.send_to(&buf, addr).await?
204        } else {
205            0
206        };
207
208        Ok(len)
209    }
210
211    fn set_protocol_version(&mut self, version: MavlinkVersion) {
212        self.protocol_version = version;
213    }
214
215    fn protocol_version(&self) -> MavlinkVersion {
216        self.protocol_version
217    }
218
219    fn set_allow_recv_any_version(&mut self, allow: bool) {
220        self.recv_any_version = allow;
221    }
222
223    fn allow_recv_any_version(&self) -> bool {
224        self.recv_any_version
225    }
226
227    #[cfg(feature = "signing")]
228    fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
229        self.signing_data = signing_data.map(SigningData::from_config);
230    }
231}
232
233#[async_trait]
234impl AsyncConnectable for UdpConfig {
235    async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
236    where
237        M: Message + Sync + Send,
238    {
239        let (addr, server, dest): (&str, _, _) = match self.mode {
240            UdpMode::Udpin => (&self.address, true, None),
241            _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
242        };
243        let socket = UdpSocket::bind(addr).await?;
244        if matches!(self.mode, UdpMode::Udpcast) {
245            socket.set_broadcast(true)?;
246        }
247        Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// Two 50-byte datagrams read through a 30-byte buffer must come back as 30 + 20 bytes
    /// each, in order, with the sender's address recorded.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Port 0 lets the OS choose a free port, so the test cannot collide with other
        // processes or with tests running in parallel.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect();
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];
        for _ in 0..2 {
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(&buf[..n_read], &datagram[..30]);

            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }

        assert_eq!(
            udp_reader.last_recv_address,
            Some(sender_socket.local_addr().unwrap())
        );
    }
}