mavlink_core/async_connection/
udp.rs

1//! Async UDP MAVLink connection
2
3use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10    io::{AsyncRead, ReadBuf},
11    net::UdpSocket,
12};
13
14use crate::{
15    async_peek_reader::AsyncPeekReader,
16    connectable::{UdpConnectable, UdpMode},
17    MavHeader, MavlinkVersion, Message, ReadVersion,
18};
19
20use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
21
22#[cfg(not(feature = "signing"))]
23use crate::{read_versioned_msg_async, write_versioned_msg_async};
24#[cfg(feature = "signing")]
25use crate::{
26    read_versioned_msg_async_signed, write_versioned_msg_signed, SigningConfig, SigningData,
27};
28
29struct UdpRead {
30    socket: Arc<UdpSocket>,
31    buffer: VecDeque<u8>,
32    last_recv_address: Option<std::net::SocketAddr>,
33}
34
35const MTU_SIZE: usize = 1500;
36impl AsyncRead for UdpRead {
37    fn poll_read(
38        mut self: core::pin::Pin<&mut Self>,
39        cx: &mut core::task::Context<'_>,
40        buf: &mut ReadBuf<'_>,
41    ) -> Poll<io::Result<()>> {
42        if self.buffer.is_empty() {
43            let mut read_buffer = [0u8; MTU_SIZE];
44            let mut read_buffer = ReadBuf::new(&mut read_buffer);
45
46            match self.socket.poll_recv_from(cx, &mut read_buffer) {
47                Poll::Ready(Ok(address)) => {
48                    let n_buffer = read_buffer.filled().len();
49
50                    let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
51                    buf.advance(n);
52
53                    self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
54                    self.last_recv_address = Some(address);
55                    Poll::Ready(Ok(()))
56                }
57                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
58                Poll::Pending => Poll::Pending,
59            }
60        } else {
61            let read_result = self.buffer.read(buf.initialize_unfilled());
62            let result = match read_result {
63                Ok(n) => {
64                    buf.advance(n);
65                    Ok(())
66                }
67                Err(err) => Err(err),
68            };
69            Poll::Ready(result)
70        }
71    }
72}
73
74struct UdpWrite {
75    socket: Arc<UdpSocket>,
76    dest: Option<std::net::SocketAddr>,
77    sequence: u8,
78}
79
80pub struct AsyncUdpConnection {
81    reader: Mutex<AsyncPeekReader<UdpRead>>,
82    writer: Mutex<UdpWrite>,
83    protocol_version: MavlinkVersion,
84    recv_any_version: bool,
85    server: bool,
86    #[cfg(feature = "signing")]
87    signing_data: Option<SigningData>,
88}
89
90impl AsyncUdpConnection {
91    fn new(
92        socket: UdpSocket,
93        server: bool,
94        dest: Option<std::net::SocketAddr>,
95    ) -> io::Result<Self> {
96        let socket = Arc::new(socket);
97        Ok(Self {
98            server,
99            reader: Mutex::new(AsyncPeekReader::new(UdpRead {
100                socket: socket.clone(),
101                buffer: VecDeque::new(),
102                last_recv_address: None,
103            })),
104            writer: Mutex::new(UdpWrite {
105                socket,
106                dest,
107                sequence: 0,
108            }),
109            protocol_version: MavlinkVersion::V2,
110            recv_any_version: false,
111            #[cfg(feature = "signing")]
112            signing_data: None,
113        })
114    }
115}
116
117#[async_trait::async_trait]
118impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
119    async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
120        let mut reader = self.reader.lock().await;
121        let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
122        loop {
123            #[cfg(not(feature = "signing"))]
124            let result = read_versioned_msg_async(reader.deref_mut(), version).await;
125            #[cfg(feature = "signing")]
126            let result = read_versioned_msg_async_signed(
127                reader.deref_mut(),
128                version,
129                self.signing_data.as_ref(),
130            )
131            .await;
132            if self.server {
133                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
134                    self.writer.lock().await.dest = addr;
135                }
136            }
137            if let ok @ Ok(..) = result {
138                return ok;
139            }
140        }
141    }
142
143    async fn send(
144        &self,
145        header: &MavHeader,
146        data: &M,
147    ) -> Result<usize, crate::error::MessageWriteError> {
148        let mut guard = self.writer.lock().await;
149        let state = &mut *guard;
150
151        let header = MavHeader {
152            sequence: state.sequence,
153            system_id: header.system_id,
154            component_id: header.component_id,
155        };
156
157        state.sequence = state.sequence.wrapping_add(1);
158
159        let len = if let Some(addr) = state.dest {
160            let mut buf = Vec::new();
161            #[cfg(not(feature = "signing"))]
162            write_versioned_msg_async(
163                &mut buf,
164                self.protocol_version,
165                header,
166                data,
167                #[cfg(feature = "signing")]
168                self.signing_data.as_ref(),
169            )
170            .await?;
171            #[cfg(feature = "signing")]
172            write_versioned_msg_signed(
173                &mut buf,
174                self.protocol_version,
175                header,
176                data,
177                #[cfg(feature = "signing")]
178                self.signing_data.as_ref(),
179            )?;
180            state.socket.send_to(&buf, addr).await?
181        } else {
182            0
183        };
184
185        Ok(len)
186    }
187
188    fn set_protocol_version(&mut self, version: MavlinkVersion) {
189        self.protocol_version = version;
190    }
191
192    fn protocol_version(&self) -> MavlinkVersion {
193        self.protocol_version
194    }
195
196    fn set_allow_recv_any_version(&mut self, allow: bool) {
197        self.recv_any_version = allow
198    }
199
200    fn allow_recv_any_version(&self) -> bool {
201        self.recv_any_version
202    }
203
204    #[cfg(feature = "signing")]
205    fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
206        self.signing_data = signing_data.map(SigningData::from_config)
207    }
208}
209
210#[async_trait]
211impl AsyncConnectable for UdpConnectable {
212    async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
213    where
214        M: Message + Sync + Send,
215    {
216        let (addr, server, dest): (&str, _, _) = match self.mode {
217            UdpMode::Udpin => (&self.address, true, None),
218            _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
219        };
220        let socket = UdpSocket::bind(addr).await?;
221        if matches!(self.mode, UdpMode::Udpcast) {
222            socket.set_broadcast(true)?;
223        }
224        Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// A datagram larger than the caller's buffer is handed out over several reads,
    /// and the next datagram is only consumed once the previous one is drained.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Ephemeral ports keep the test independent of whatever else is listening
        // on the machine and allow tests to run in parallel.
        let receiver_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: Arc::new(receiver_socket),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };

        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let sender_addr = sender_socket.local_addr().unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect();
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];
        for _ in 0..2 {
            // The first read takes as much of the datagram as fits...
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            // ...and the second read returns the buffered remainder.
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }

        assert_eq!(udp_reader.last_recv_address, Some(sender_addr));
    }
}