mavlink_core/async_connection/
udp.rs

1//! Async UDP MAVLink connection
2
3use core::{ops::DerefMut, task::Poll};
4use std::{collections::VecDeque, io::Read, sync::Arc};
5
6use async_trait::async_trait;
7use tokio::{
8    io::{self, AsyncRead, ReadBuf},
9    net::UdpSocket,
10    sync::Mutex,
11};
12
13use crate::{
14    async_peek_reader::AsyncPeekReader,
15    connectable::{UdpConnectable, UdpMode},
16    MavHeader, MavlinkVersion, Message,
17};
18
19use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
20
21#[cfg(not(feature = "signing"))]
22use crate::{read_versioned_msg_async, write_versioned_msg_async};
23#[cfg(feature = "signing")]
24use crate::{
25    read_versioned_msg_async_signed, write_versioned_msg_signed, SigningConfig, SigningData,
26};
27
28struct UdpRead {
29    socket: Arc<UdpSocket>,
30    buffer: VecDeque<u8>,
31    last_recv_address: Option<std::net::SocketAddr>,
32}
33
34const MTU_SIZE: usize = 1500;
35impl AsyncRead for UdpRead {
36    fn poll_read(
37        mut self: core::pin::Pin<&mut Self>,
38        cx: &mut core::task::Context<'_>,
39        buf: &mut io::ReadBuf<'_>,
40    ) -> Poll<io::Result<()>> {
41        if self.buffer.is_empty() {
42            let mut read_buffer = [0u8; MTU_SIZE];
43            let mut read_buffer = ReadBuf::new(&mut read_buffer);
44
45            match self.socket.poll_recv_from(cx, &mut read_buffer) {
46                Poll::Ready(Ok(address)) => {
47                    let n_buffer = read_buffer.filled().len();
48
49                    let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
50                    buf.advance(n);
51
52                    self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
53                    self.last_recv_address = Some(address);
54                    Poll::Ready(Ok(()))
55                }
56                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
57                Poll::Pending => Poll::Pending,
58            }
59        } else {
60            let read_result = self.buffer.read(buf.initialize_unfilled());
61            let result = match read_result {
62                Ok(n) => {
63                    buf.advance(n);
64                    Ok(())
65                }
66                Err(err) => Err(err),
67            };
68            Poll::Ready(result)
69        }
70    }
71}
72
73struct UdpWrite {
74    socket: Arc<UdpSocket>,
75    dest: Option<std::net::SocketAddr>,
76    sequence: u8,
77}
78
79pub struct AsyncUdpConnection {
80    reader: Mutex<AsyncPeekReader<UdpRead>>,
81    writer: Mutex<UdpWrite>,
82    protocol_version: MavlinkVersion,
83    server: bool,
84    #[cfg(feature = "signing")]
85    signing_data: Option<SigningData>,
86}
87
88impl AsyncUdpConnection {
89    fn new(
90        socket: UdpSocket,
91        server: bool,
92        dest: Option<std::net::SocketAddr>,
93    ) -> io::Result<Self> {
94        let socket = Arc::new(socket);
95        Ok(Self {
96            server,
97            reader: Mutex::new(AsyncPeekReader::new(UdpRead {
98                socket: socket.clone(),
99                buffer: VecDeque::new(),
100                last_recv_address: None,
101            })),
102            writer: Mutex::new(UdpWrite {
103                socket,
104                dest,
105                sequence: 0,
106            }),
107            protocol_version: MavlinkVersion::V2,
108            #[cfg(feature = "signing")]
109            signing_data: None,
110        })
111    }
112}
113
114#[async_trait::async_trait]
115impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
116    async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
117        let mut reader = self.reader.lock().await;
118
119        loop {
120            #[cfg(not(feature = "signing"))]
121            let result = read_versioned_msg_async(reader.deref_mut(), self.protocol_version).await;
122            #[cfg(feature = "signing")]
123            let result = read_versioned_msg_async_signed(
124                reader.deref_mut(),
125                self.protocol_version,
126                self.signing_data.as_ref(),
127            )
128            .await;
129            if self.server {
130                if let addr @ Some(_) = reader.reader_ref().last_recv_address {
131                    self.writer.lock().await.dest = addr;
132                }
133            }
134            if let ok @ Ok(..) = result {
135                return ok;
136            }
137        }
138    }
139
140    async fn send(
141        &self,
142        header: &MavHeader,
143        data: &M,
144    ) -> Result<usize, crate::error::MessageWriteError> {
145        let mut guard = self.writer.lock().await;
146        let state = &mut *guard;
147
148        let header = MavHeader {
149            sequence: state.sequence,
150            system_id: header.system_id,
151            component_id: header.component_id,
152        };
153
154        state.sequence = state.sequence.wrapping_add(1);
155
156        let len = if let Some(addr) = state.dest {
157            let mut buf = Vec::new();
158            #[cfg(not(feature = "signing"))]
159            write_versioned_msg_async(
160                &mut buf,
161                self.protocol_version,
162                header,
163                data,
164                #[cfg(feature = "signing")]
165                self.signing_data.as_ref(),
166            )
167            .await?;
168            #[cfg(feature = "signing")]
169            write_versioned_msg_signed(
170                &mut buf,
171                self.protocol_version,
172                header,
173                data,
174                #[cfg(feature = "signing")]
175                self.signing_data.as_ref(),
176            )?;
177            state.socket.send_to(&buf, addr).await?
178        } else {
179            0
180        };
181
182        Ok(len)
183    }
184
185    fn set_protocol_version(&mut self, version: MavlinkVersion) {
186        self.protocol_version = version;
187    }
188
189    fn get_protocol_version(&self) -> MavlinkVersion {
190        self.protocol_version
191    }
192
193    #[cfg(feature = "signing")]
194    fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
195        self.signing_data = signing_data.map(SigningData::from_config)
196    }
197}
198
199#[async_trait]
200impl AsyncConnectable for UdpConnectable {
201    async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
202    where
203        M: Message + Sync + Send,
204    {
205        let (addr, server, dest): (&str, _, _) = match self.mode {
206            UdpMode::Udpin => (&self.address, true, None),
207            _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
208        };
209        let socket = UdpSocket::bind(addr).await?;
210        if matches!(self.mode, UdpMode::Udpcast) {
211            socket.set_broadcast(true)?;
212        }
213        Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use io::AsyncReadExt;

    /// Two 50-byte datagrams read through a 30-byte buffer come back as 30 + 20 bytes
    /// each: leftovers of one datagram are served before the next one is received.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Port 0 lets the OS choose a free port, so the test cannot collide with other
        // tests or services already bound to a fixed port.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect();

        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];

        for _ in 0..2 {
            // First read of a datagram fills the whole 30-byte buffer...
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            // ...and the second returns only the 20 buffered leftover bytes.
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }
    }
}