mavlink_core/async_connection/
udp.rs

1//! Async UDP MAVLink connection
2
3use core::{ops::DerefMut, task::Poll};
4use std::{collections::VecDeque, io::Read, sync::Arc};
5
6use async_trait::async_trait;
7use tokio::{
8    io::{self, AsyncRead, ReadBuf},
9    net::UdpSocket,
10    sync::Mutex,
11};
12
13use crate::{
14    async_peek_reader::AsyncPeekReader,
15    connectable::{UdpConnectable, UdpMode},
16    MavHeader, MavlinkVersion, Message, ReadVersion,
17};
18
19use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
20
21#[cfg(not(feature = "signing"))]
22use crate::{read_versioned_msg_async, write_versioned_msg_async};
23#[cfg(feature = "signing")]
24use crate::{
25    read_versioned_msg_async_signed, write_versioned_msg_signed, SigningConfig, SigningData,
26};
27
28struct UdpRead {
29    socket: Arc<UdpSocket>,
30    buffer: VecDeque<u8>,
31    last_recv_address: Option<std::net::SocketAddr>,
32}
33
/// Capacity, in bytes, of the local buffer each incoming datagram is received into.
///
/// NOTE(review): datagrams longer than this cannot be received intact; depending on the
/// platform the OS truncates them or fails the receive.
const MTU_SIZE: usize = 1_500;
35impl AsyncRead for UdpRead {
36    fn poll_read(
37        mut self: core::pin::Pin<&mut Self>,
38        cx: &mut core::task::Context<'_>,
39        buf: &mut io::ReadBuf<'_>,
40    ) -> Poll<io::Result<()>> {
41        if self.buffer.is_empty() {
42            let mut read_buffer = [0u8; MTU_SIZE];
43            let mut read_buffer = ReadBuf::new(&mut read_buffer);
44
45            match self.socket.poll_recv_from(cx, &mut read_buffer) {
46                Poll::Ready(Ok(address)) => {
47                    let n_buffer = read_buffer.filled().len();
48
49                    let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
50                    buf.advance(n);
51
52                    self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
53                    self.last_recv_address = Some(address);
54                    Poll::Ready(Ok(()))
55                }
56                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
57                Poll::Pending => Poll::Pending,
58            }
59        } else {
60            let read_result = self.buffer.read(buf.initialize_unfilled());
61            let result = match read_result {
62                Ok(n) => {
63                    buf.advance(n);
64                    Ok(())
65                }
66                Err(err) => Err(err),
67            };
68            Poll::Ready(result)
69        }
70    }
71}
72
73struct UdpWrite {
74    socket: Arc<UdpSocket>,
75    dest: Option<std::net::SocketAddr>,
76    sequence: u8,
77}
78
79pub struct AsyncUdpConnection {
80    reader: Mutex<AsyncPeekReader<UdpRead>>,
81    writer: Mutex<UdpWrite>,
82    protocol_version: MavlinkVersion,
83    recv_any_version: bool,
84    server: bool,
85    #[cfg(feature = "signing")]
86    signing_data: Option<SigningData>,
87}
88
89impl AsyncUdpConnection {
90    fn new(
91        socket: UdpSocket,
92        server: bool,
93        dest: Option<std::net::SocketAddr>,
94    ) -> io::Result<Self> {
95        let socket = Arc::new(socket);
96        Ok(Self {
97            server,
98            reader: Mutex::new(AsyncPeekReader::new(UdpRead {
99                socket: socket.clone(),
100                buffer: VecDeque::new(),
101                last_recv_address: None,
102            })),
103            writer: Mutex::new(UdpWrite {
104                socket,
105                dest,
106                sequence: 0,
107            }),
108            protocol_version: MavlinkVersion::V2,
109            recv_any_version: false,
110            #[cfg(feature = "signing")]
111            signing_data: None,
112        })
113    }
114}
115
116#[async_trait::async_trait]
117impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
118    async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
119        let mut reader = self.reader.lock().await;
120        let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
121        loop {
122            #[cfg(not(feature = "signing"))]
123            let result = read_versioned_msg_async(reader.deref_mut(), version).await;
124            #[cfg(feature = "signing")]
125            let result = read_versioned_msg_async_signed(
126                reader.deref_mut(),
127                version,
128                self.signing_data.as_ref(),
129            )
130            .await;
131            if self.server {
132                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
133                    self.writer.lock().await.dest = addr;
134                }
135            }
136            if let ok @ Ok(..) = result {
137                return ok;
138            }
139        }
140    }
141
142    async fn send(
143        &self,
144        header: &MavHeader,
145        data: &M,
146    ) -> Result<usize, crate::error::MessageWriteError> {
147        let mut guard = self.writer.lock().await;
148        let state = &mut *guard;
149
150        let header = MavHeader {
151            sequence: state.sequence,
152            system_id: header.system_id,
153            component_id: header.component_id,
154        };
155
156        state.sequence = state.sequence.wrapping_add(1);
157
158        let len = if let Some(addr) = state.dest {
159            let mut buf = Vec::new();
160            #[cfg(not(feature = "signing"))]
161            write_versioned_msg_async(
162                &mut buf,
163                self.protocol_version,
164                header,
165                data,
166                #[cfg(feature = "signing")]
167                self.signing_data.as_ref(),
168            )
169            .await?;
170            #[cfg(feature = "signing")]
171            write_versioned_msg_signed(
172                &mut buf,
173                self.protocol_version,
174                header,
175                data,
176                #[cfg(feature = "signing")]
177                self.signing_data.as_ref(),
178            )?;
179            state.socket.send_to(&buf, addr).await?
180        } else {
181            0
182        };
183
184        Ok(len)
185    }
186
187    fn set_protocol_version(&mut self, version: MavlinkVersion) {
188        self.protocol_version = version;
189    }
190
191    fn protocol_version(&self) -> MavlinkVersion {
192        self.protocol_version
193    }
194
195    fn set_allow_recv_any_version(&mut self, allow: bool) {
196        self.recv_any_version = allow
197    }
198
199    fn allow_recv_any_version(&self) -> bool {
200        self.recv_any_version
201    }
202
203    #[cfg(feature = "signing")]
204    fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
205        self.signing_data = signing_data.map(SigningData::from_config)
206    }
207}
208
209#[async_trait]
210impl AsyncConnectable for UdpConnectable {
211    async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
212    where
213        M: Message + Sync + Send,
214    {
215        let (addr, server, dest): (&str, _, _) = match self.mode {
216            UdpMode::Udpin => (&self.address, true, None),
217            _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
218        };
219        let socket = UdpSocket::bind(addr).await?;
220        if matches!(self.mode, UdpMode::Udpcast) {
221            socket.set_broadcast(true)?;
222        }
223        Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use io::AsyncReadExt;

    /// A datagram larger than the caller's buffer must be delivered across several reads,
    /// in order, without losing bytes or mixing in the following datagram.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Ephemeral ports avoid collisions with other processes and with tests running in
        // parallel; a hard-coded port makes the test flaky.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };

        // Bind the sender to loopback (not the wildcard) so its local address equals the
        // source address the receiver observes.
        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();
        let sender_addr = sender_socket.local_addr().unwrap();

        let datagram: Vec<u8> = (0..50).collect();

        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];

        // Each 50-byte datagram arrives as a 30-byte read followed by a 20-byte read.
        for _ in 0..2 {
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }

        assert_eq!(udp_reader.last_recv_address, Some(sender_addr));
    }
}