mavlink_core/connection/udp/
async.rs1use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10 io::{AsyncRead, ReadBuf},
11 net::UdpSocket,
12};
13
14use crate::MAVLinkMessageRaw;
15use crate::connection::udp::config::{UdpConfig, UdpMode};
16use crate::connection_shared::{
17 ConnectionState, next_send_header, read_message_async, read_raw_message_async,
18 write_message_async,
19};
20use crate::{MavHeader, MavlinkVersion, Message, async_peek_reader::AsyncPeekReader};
21
22use crate::connection::{AsyncConnectable, AsyncMavConnection, get_socket_addr};
23
24#[cfg(feature = "mav2-message-signing")]
25use crate::SigningConfig;
26
27struct UdpRead {
28 socket: Arc<UdpSocket>,
29 buffer: VecDeque<u8>,
30 last_recv_address: Option<std::net::SocketAddr>,
31}
32
/// Size in bytes of the local buffer each UDP datagram is received into.
///
/// 1500 bytes matches a typical Ethernet MTU. Datagrams longer than this do not
/// fit, and the receive call may discard their excess bytes.
const MTU_SIZE: usize = 1_500;
34impl AsyncRead for UdpRead {
35 fn poll_read(
36 mut self: core::pin::Pin<&mut Self>,
37 cx: &mut core::task::Context<'_>,
38 buf: &mut ReadBuf<'_>,
39 ) -> Poll<io::Result<()>> {
40 if self.buffer.is_empty() {
41 let mut read_buffer = [0u8; MTU_SIZE];
42 let mut read_buffer = ReadBuf::new(&mut read_buffer);
43
44 match self.socket.poll_recv_from(cx, &mut read_buffer) {
45 Poll::Ready(Ok(address)) => {
46 let n_buffer = read_buffer.filled().len();
47
48 let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
49 buf.advance(n);
50
51 self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
52 self.last_recv_address = Some(address);
53 Poll::Ready(Ok(()))
54 }
55 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
56 Poll::Pending => Poll::Pending,
57 }
58 } else {
59 let read_result = self.buffer.read(buf.initialize_unfilled());
60 let result = match read_result {
61 Ok(n) => {
62 buf.advance(n);
63 Ok(())
64 }
65 Err(err) => Err(err),
66 };
67 Poll::Ready(result)
68 }
69 }
70}
71
72struct UdpWrite {
73 socket: Arc<UdpSocket>,
74 dest: Option<std::net::SocketAddr>,
75 sequence: u8,
76}
77
78pub struct AsyncUdpConnection {
79 reader: Mutex<AsyncPeekReader<UdpRead>>,
80 writer: Mutex<UdpWrite>,
81 state: ConnectionState,
82 server: bool,
83}
84
85impl AsyncUdpConnection {
86 fn new(
87 socket: UdpSocket,
88 server: bool,
89 dest: Option<std::net::SocketAddr>,
90 ) -> io::Result<Self> {
91 let socket = Arc::new(socket);
92 Ok(Self {
93 server,
94 reader: Mutex::new(AsyncPeekReader::new(UdpRead {
95 socket: socket.clone(),
96 buffer: VecDeque::new(),
97 last_recv_address: None,
98 })),
99 writer: Mutex::new(UdpWrite {
100 socket,
101 dest,
102 sequence: 0,
103 }),
104 state: ConnectionState::new(),
105 })
106 }
107
108 async fn update_reply_destination(&self, reader: &mut AsyncPeekReader<UdpRead>) {
109 if self.server {
110 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
111 self.writer.lock().await.dest = addr;
112 }
113 }
114 }
115}
116
117#[async_trait::async_trait]
118impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
119 async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
120 let mut reader = self.reader.lock().await;
121 loop {
122 let result = read_message_async::<M, _>(reader.deref_mut(), &self.state).await;
123 self.update_reply_destination(reader.deref_mut()).await;
124 if let ok @ Ok(..) = result {
125 return ok;
126 }
127 }
128 }
129
130 async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
131 let mut reader = self.reader.lock().await;
132 loop {
133 let result = read_raw_message_async::<M, _>(reader.deref_mut(), &self.state).await;
134 self.update_reply_destination(reader.deref_mut()).await;
135 if let ok @ Ok(..) = result {
136 return ok;
137 }
138 }
139 }
140
141 async fn try_recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
142 let mut reader = self.reader.lock().await;
143 let result = read_message_async::<M, _>(reader.deref_mut(), &self.state).await;
144 self.update_reply_destination(reader.deref_mut()).await;
145
146 result
147 }
148
149 async fn send(
150 &self,
151 header: &MavHeader,
152 data: &M,
153 ) -> Result<usize, crate::error::MessageWriteError> {
154 let mut guard = self.writer.lock().await;
155 let state = &mut *guard;
156
157 let header = next_send_header(&mut state.sequence, header);
158
159 let len = if let Some(addr) = state.dest {
160 let mut buf = Vec::new();
161 write_message_async(&mut buf, &self.state, header, data).await?;
162 state.socket.send_to(&buf, addr).await?
163 } else {
164 0
165 };
166
167 Ok(len)
168 }
169
170 fn set_protocol_version(&mut self, version: MavlinkVersion) {
171 self.state.set_protocol_version(version);
172 }
173
174 fn protocol_version(&self) -> MavlinkVersion {
175 self.state.protocol_version()
176 }
177
178 fn set_allow_recv_any_version(&mut self, allow: bool) {
179 self.state.set_allow_recv_any_version(allow);
180 }
181
182 fn allow_recv_any_version(&self) -> bool {
183 self.state.allow_recv_any_version()
184 }
185
186 #[cfg(feature = "mav2-message-signing")]
187 fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
188 self.state.setup_signing(signing_data);
189 }
190}
191
192#[async_trait]
193impl AsyncConnectable for UdpConfig {
194 async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
195 where
196 M: Message + Sync + Send,
197 {
198 let (addr, server, dest): (&str, _, _) = match self.mode {
199 UdpMode::Udpin => (&self.address, true, None),
200 _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
201 };
202 let socket = UdpSocket::bind(addr).await?;
203 if matches!(self.mode, UdpMode::UdpBroadcast) {
204 socket.set_broadcast(true)?;
205 }
206 Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// Bytes of a datagram that do not fit in the caller's buffer must be kept
    /// and returned by later reads, before the next datagram is consumed.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Bind ephemeral ports rather than a fixed one, so parallel test runs or
        // other local processes cannot collide with this test.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };

        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let sender_addr = sender_socket.local_addr().unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect();
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        // Each 50-byte datagram comes back as a 30-byte read followed by a
        // 20-byte read, because the read buffer only holds 30 bytes.
        let mut buf = [0u8; 30];
        for _ in 0..2 {
            for expected in [0u8..30, 30..50] {
                let n_read = udp_reader.read(&mut buf).await.unwrap();
                assert_eq!(n_read, expected.len());
                assert_eq!(&buf[..n_read], expected.collect::<Vec<u8>>().as_slice());
                assert_eq!(udp_reader.last_recv_address, Some(sender_addr));
            }
        }
    }
}