mavlink_core/connection/udp/
async.rs1use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10 io::{AsyncRead, AsyncWrite, ReadBuf},
11 net::UdpSocket,
12};
13
14use crate::MAVLinkMessageRaw;
15use crate::connection::udp::config::{UdpConfig, UdpMode};
16use crate::connection_shared::{
17 ConnectionState, next_send_header, read_message_async, read_raw_message_async,
18 write_message_async, write_raw_message_async,
19};
20use crate::{MavHeader, MavlinkVersion, Message, async_peek_reader::AsyncPeekReader};
21
22use crate::connection::{AsyncConnectable, AsyncMavConnection, get_socket_addr};
23
24#[cfg(feature = "mav2-message-signing")]
25use crate::SigningConfig;
26
27struct UdpRead {
28 socket: Arc<UdpSocket>,
29 buffer: VecDeque<u8>,
30 last_recv_address: Option<std::net::SocketAddr>,
31}
32
33const MTU_SIZE: usize = 1500;
34impl AsyncRead for UdpRead {
35 fn poll_read(
36 mut self: core::pin::Pin<&mut Self>,
37 cx: &mut core::task::Context<'_>,
38 buf: &mut ReadBuf<'_>,
39 ) -> Poll<io::Result<()>> {
40 if self.buffer.is_empty() {
41 let mut read_buffer = [0u8; MTU_SIZE];
42 let mut read_buffer = ReadBuf::new(&mut read_buffer);
43
44 match self.socket.poll_recv_from(cx, &mut read_buffer) {
45 Poll::Ready(Ok(address)) => {
46 let n_buffer = read_buffer.filled().len();
47
48 let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
49 buf.advance(n);
50
51 self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
52 self.last_recv_address = Some(address);
53 Poll::Ready(Ok(()))
54 }
55 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
56 Poll::Pending => Poll::Pending,
57 }
58 } else {
59 let read_result = self.buffer.read(buf.initialize_unfilled());
60 let result = match read_result {
61 Ok(n) => {
62 buf.advance(n);
63 Ok(())
64 }
65 Err(err) => Err(err),
66 };
67 Poll::Ready(result)
68 }
69 }
70}
71
72struct UdpWrite {
73 socket: Arc<UdpSocket>,
74 dest: Option<std::net::SocketAddr>,
75 sequence: u8,
76}
77
78impl AsyncWrite for UdpWrite {
79 fn poll_write(
80 self: core::pin::Pin<&mut Self>,
81 cx: &mut core::task::Context<'_>,
82 buf: &[u8],
83 ) -> Poll<io::Result<usize>> {
84 let this = self.get_mut();
85 let addr = this.dest.expect("`dest` is checked before write");
86
87 match this.socket.poll_send_to(cx, buf, addr) {
88 Poll::Ready(Ok(written)) if written == buf.len() => Poll::Ready(Ok(written)),
89 Poll::Ready(Ok(_)) => Poll::Ready(Err(io::Error::new(
90 io::ErrorKind::WriteZero,
91 "failed to send complete UDP datagram",
92 ))),
93 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
94 Poll::Pending => Poll::Pending,
95 }
96 }
97
98 fn poll_flush(
99 self: core::pin::Pin<&mut Self>,
100 _cx: &mut core::task::Context<'_>,
101 ) -> Poll<io::Result<()>> {
102 Poll::Ready(Ok(()))
103 }
104
105 fn poll_shutdown(
106 self: core::pin::Pin<&mut Self>,
107 _cx: &mut core::task::Context<'_>,
108 ) -> Poll<io::Result<()>> {
109 Poll::Ready(Ok(()))
110 }
111}
112
113pub struct AsyncUdpConnection {
114 reader: Mutex<AsyncPeekReader<UdpRead>>,
115 writer: Mutex<UdpWrite>,
116 state: ConnectionState,
117 server: bool,
118}
119
120impl AsyncUdpConnection {
121 fn new(
122 socket: UdpSocket,
123 server: bool,
124 dest: Option<std::net::SocketAddr>,
125 ) -> io::Result<Self> {
126 let socket = Arc::new(socket);
127 Ok(Self {
128 server,
129 reader: Mutex::new(AsyncPeekReader::new(UdpRead {
130 socket: socket.clone(),
131 buffer: VecDeque::new(),
132 last_recv_address: None,
133 })),
134 writer: Mutex::new(UdpWrite {
135 socket,
136 dest,
137 sequence: 0,
138 }),
139 state: ConnectionState::new(),
140 })
141 }
142
143 async fn update_reply_destination(&self, reader: &mut AsyncPeekReader<UdpRead>) {
144 if self.server {
145 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
146 self.writer.lock().await.dest = addr;
147 }
148 }
149 }
150}
151
152#[async_trait::async_trait]
153impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
154 async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
155 let mut reader = self.reader.lock().await;
156 loop {
157 let result = read_message_async::<M, _>(reader.deref_mut(), &self.state).await;
158 self.update_reply_destination(reader.deref_mut()).await;
159 if let ok @ Ok(..) = result {
160 return ok;
161 }
162 }
163 }
164
165 async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
166 let mut reader = self.reader.lock().await;
167 loop {
168 let result = read_raw_message_async::<M, _>(reader.deref_mut(), &self.state).await;
169 self.update_reply_destination(reader.deref_mut()).await;
170 if let ok @ Ok(..) = result {
171 return ok;
172 }
173 }
174 }
175
176 async fn try_recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
177 let mut reader = self.reader.lock().await;
178 let result = read_message_async::<M, _>(reader.deref_mut(), &self.state).await;
179 self.update_reply_destination(reader.deref_mut()).await;
180
181 result
182 }
183
184 async fn send(
185 &self,
186 header: &MavHeader,
187 data: &M,
188 ) -> Result<usize, crate::error::MessageWriteError> {
189 let mut guard = self.writer.lock().await;
190 let writer = &mut *guard;
191
192 let header = next_send_header(&mut writer.sequence, header);
193
194 let len = if writer.dest.is_some() {
195 write_message_async(writer, &self.state, header, data).await?
196 } else {
197 0
198 };
199
200 Ok(len)
201 }
202
203 async fn send_raw(
204 &self,
205 data: &MAVLinkMessageRaw,
206 ) -> Result<usize, crate::error::MessageWriteError> {
207 let mut guard = self.writer.lock().await;
208 let writer = &mut *guard;
209
210 let len = if writer.dest.is_some() {
211 write_raw_message_async(writer, data).await?
212 } else {
213 0
214 };
215
216 Ok(len)
217 }
218
219 fn set_protocol_version(&mut self, version: MavlinkVersion) {
220 self.state.set_protocol_version(version);
221 }
222
223 fn protocol_version(&self) -> MavlinkVersion {
224 self.state.protocol_version()
225 }
226
227 fn set_allow_recv_any_version(&mut self, allow: bool) {
228 self.state.set_allow_recv_any_version(allow);
229 }
230
231 fn allow_recv_any_version(&self) -> bool {
232 self.state.allow_recv_any_version()
233 }
234
235 #[cfg(feature = "mav2-message-signing")]
236 fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
237 self.state.setup_signing(signing_data);
238 }
239}
240
241#[async_trait]
242impl AsyncConnectable for UdpConfig {
243 async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
244 where
245 M: Message + Sync + Send,
246 {
247 let (addr, server, dest): (&str, _, _) = match self.mode {
248 UdpMode::Udpin => (&self.address, true, None),
249 _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
250 };
251 let socket = UdpSocket::bind(addr).await?;
252 if matches!(self.mode, UdpMode::UdpBroadcast) {
253 socket.set_broadcast(true)?;
254 }
255 Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// A datagram larger than the caller's buffer is delivered across several reads. The next
    /// datagram is only consumed once the previous one has been drained completely.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Use an ephemeral port so the test cannot collide with parallel tests or with other
        // services that happen to hold a fixed port.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: Arc::clone(&receiver_socket),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };

        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();
        let sender_addr = sender_socket.local_addr().unwrap();

        let datagram: Vec<u8> = (0..50).collect();
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];
        for _ in 0..2 {
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }

        assert_eq!(udp_reader.last_recv_address, Some(sender_addr));
    }
}