mavlink_core/async_connection/
udp.rs

//! Async UDP MAVLink connection.
2
3use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10    io::{AsyncRead, ReadBuf},
11    net::UdpSocket,
12};
13
14use crate::connection::udp::config::{UdpConfig, UdpMode};
15use crate::MAVLinkMessageRaw;
16use crate::{async_peek_reader::AsyncPeekReader, MavHeader, MavlinkVersion, Message, ReadVersion};
17
18use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
19
20#[cfg(not(feature = "signing"))]
21use crate::{
22    read_versioned_msg_async, read_versioned_raw_message_async, write_versioned_msg_async,
23};
24#[cfg(feature = "signing")]
25use crate::{
26    read_versioned_msg_async_signed, read_versioned_raw_message_async_signed,
27    write_versioned_msg_signed, SigningConfig, SigningData,
28};
29
30struct UdpRead {
31    socket: Arc<UdpSocket>,
32    buffer: VecDeque<u8>,
33    last_recv_address: Option<std::net::SocketAddr>,
34}
35
/// Capacity, in bytes, of the stack buffer each incoming datagram is received into.
/// Bytes of a longer datagram beyond this capacity may be discarded by the socket.
const MTU_SIZE: usize = 1_500;
37impl AsyncRead for UdpRead {
38    fn poll_read(
39        mut self: core::pin::Pin<&mut Self>,
40        cx: &mut core::task::Context<'_>,
41        buf: &mut ReadBuf<'_>,
42    ) -> Poll<io::Result<()>> {
43        if self.buffer.is_empty() {
44            let mut read_buffer = [0u8; MTU_SIZE];
45            let mut read_buffer = ReadBuf::new(&mut read_buffer);
46
47            match self.socket.poll_recv_from(cx, &mut read_buffer) {
48                Poll::Ready(Ok(address)) => {
49                    let n_buffer = read_buffer.filled().len();
50
51                    let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
52                    buf.advance(n);
53
54                    self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
55                    self.last_recv_address = Some(address);
56                    Poll::Ready(Ok(()))
57                }
58                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
59                Poll::Pending => Poll::Pending,
60            }
61        } else {
62            let read_result = self.buffer.read(buf.initialize_unfilled());
63            let result = match read_result {
64                Ok(n) => {
65                    buf.advance(n);
66                    Ok(())
67                }
68                Err(err) => Err(err),
69            };
70            Poll::Ready(result)
71        }
72    }
73}
74
75struct UdpWrite {
76    socket: Arc<UdpSocket>,
77    dest: Option<std::net::SocketAddr>,
78    sequence: u8,
79}
80
81pub struct AsyncUdpConnection {
82    reader: Mutex<AsyncPeekReader<UdpRead>>,
83    writer: Mutex<UdpWrite>,
84    protocol_version: MavlinkVersion,
85    recv_any_version: bool,
86    server: bool,
87    #[cfg(feature = "signing")]
88    signing_data: Option<SigningData>,
89}
90
91impl AsyncUdpConnection {
92    fn new(
93        socket: UdpSocket,
94        server: bool,
95        dest: Option<std::net::SocketAddr>,
96    ) -> io::Result<Self> {
97        let socket = Arc::new(socket);
98        Ok(Self {
99            server,
100            reader: Mutex::new(AsyncPeekReader::new(UdpRead {
101                socket: socket.clone(),
102                buffer: VecDeque::new(),
103                last_recv_address: None,
104            })),
105            writer: Mutex::new(UdpWrite {
106                socket,
107                dest,
108                sequence: 0,
109            }),
110            protocol_version: MavlinkVersion::V2,
111            recv_any_version: false,
112            #[cfg(feature = "signing")]
113            signing_data: None,
114        })
115    }
116}
117
118#[async_trait::async_trait]
119impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
120    async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
121        let mut reader = self.reader.lock().await;
122        let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
123        loop {
124            #[cfg(not(feature = "signing"))]
125            let result = read_versioned_msg_async(reader.deref_mut(), version).await;
126            #[cfg(feature = "signing")]
127            let result = read_versioned_msg_async_signed(
128                reader.deref_mut(),
129                version,
130                self.signing_data.as_ref(),
131            )
132            .await;
133            if self.server {
134                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
135                    self.writer.lock().await.dest = addr;
136                }
137            }
138            if let ok @ Ok(..) = result {
139                return ok;
140            }
141        }
142    }
143
144    async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
145        let mut reader = self.reader.lock().await;
146        let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
147        loop {
148            #[cfg(not(feature = "signing"))]
149            let result =
150                read_versioned_raw_message_async::<M, _>(reader.deref_mut(), version).await;
151            #[cfg(feature = "signing")]
152            let result = read_versioned_raw_message_async_signed::<M, _>(
153                reader.deref_mut(),
154                version,
155                self.signing_data.as_ref(),
156            )
157            .await;
158            if self.server {
159                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
160                    self.writer.lock().await.dest = addr;
161                }
162            }
163            if let ok @ Ok(..) = result {
164                return ok;
165            }
166        }
167    }
168
169    async fn send(
170        &self,
171        header: &MavHeader,
172        data: &M,
173    ) -> Result<usize, crate::error::MessageWriteError> {
174        let mut guard = self.writer.lock().await;
175        let state = &mut *guard;
176
177        let header = MavHeader {
178            sequence: state.sequence,
179            system_id: header.system_id,
180            component_id: header.component_id,
181        };
182
183        state.sequence = state.sequence.wrapping_add(1);
184
185        let len = if let Some(addr) = state.dest {
186            let mut buf = Vec::new();
187            #[cfg(not(feature = "signing"))]
188            write_versioned_msg_async(
189                &mut buf,
190                self.protocol_version,
191                header,
192                data,
193                #[cfg(feature = "signing")]
194                self.signing_data.as_ref(),
195            )
196            .await?;
197            #[cfg(feature = "signing")]
198            write_versioned_msg_signed(
199                &mut buf,
200                self.protocol_version,
201                header,
202                data,
203                #[cfg(feature = "signing")]
204                self.signing_data.as_ref(),
205            )?;
206            state.socket.send_to(&buf, addr).await?
207        } else {
208            0
209        };
210
211        Ok(len)
212    }
213
214    fn set_protocol_version(&mut self, version: MavlinkVersion) {
215        self.protocol_version = version;
216    }
217
218    fn protocol_version(&self) -> MavlinkVersion {
219        self.protocol_version
220    }
221
222    fn set_allow_recv_any_version(&mut self, allow: bool) {
223        self.recv_any_version = allow;
224    }
225
226    fn allow_recv_any_version(&self) -> bool {
227        self.recv_any_version
228    }
229
230    #[cfg(feature = "signing")]
231    fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
232        self.signing_data = signing_data.map(SigningData::from_config);
233    }
234}
235
236#[async_trait]
237impl AsyncConnectable for UdpConfig {
238    async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
239    where
240        M: Message + Sync + Send,
241    {
242        let (addr, server, dest): (&str, _, _) = match self.mode {
243            UdpMode::Udpin => (&self.address, true, None),
244            _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
245        };
246        let socket = UdpSocket::bind(addr).await?;
247        if matches!(self.mode, UdpMode::Udpcast) {
248            socket.set_broadcast(true)?;
249        }
250        Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// Collects `range` into a byte vector for slice comparisons.
    fn bytes(range: core::ops::Range<u8>) -> Vec<u8> {
        range.collect()
    }

    /// A datagram larger than the caller's buffer is handed out across several reads.
    /// The next datagram is only received once the previous one is exhausted.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Ephemeral ports avoid collisions with other processes and with tests running
        // in parallel; a fixed port made this test flaky.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };

        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();
        let sender_addr = sender_socket.local_addr().unwrap();

        let datagram = bytes(0..50);
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];
        for _ in 0..2 {
            // The first read receives a datagram and fills the whole 30-byte buffer...
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], bytes(0..30).as_slice());
            assert_eq!(udp_reader.last_recv_address, Some(sender_addr));

            // ...and the second drains the buffered remainder of the same datagram.
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], bytes(30..50).as_slice());
        }
    }
}