mavlink_core/async_connection/
udp.rs1use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10 io::{AsyncRead, ReadBuf},
11 net::UdpSocket,
12};
13
14use crate::{
15 async_peek_reader::AsyncPeekReader,
16 connectable::{UdpConnectable, UdpMode},
17 MavHeader, MavlinkVersion, Message, ReadVersion,
18};
19
20use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
21
22#[cfg(not(feature = "signing"))]
23use crate::{read_versioned_msg_async, write_versioned_msg_async};
24#[cfg(feature = "signing")]
25use crate::{
26 read_versioned_msg_async_signed, write_versioned_msg_signed, SigningConfig, SigningData,
27};
28
29struct UdpRead {
30 socket: Arc<UdpSocket>,
31 buffer: VecDeque<u8>,
32 last_recv_address: Option<std::net::SocketAddr>,
33}
34
/// Capacity, in bytes, of the stack buffer each incoming datagram is received into.
const MTU_SIZE: usize = 1_500;
36impl AsyncRead for UdpRead {
37 fn poll_read(
38 mut self: core::pin::Pin<&mut Self>,
39 cx: &mut core::task::Context<'_>,
40 buf: &mut ReadBuf<'_>,
41 ) -> Poll<io::Result<()>> {
42 if self.buffer.is_empty() {
43 let mut read_buffer = [0u8; MTU_SIZE];
44 let mut read_buffer = ReadBuf::new(&mut read_buffer);
45
46 match self.socket.poll_recv_from(cx, &mut read_buffer) {
47 Poll::Ready(Ok(address)) => {
48 let n_buffer = read_buffer.filled().len();
49
50 let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
51 buf.advance(n);
52
53 self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
54 self.last_recv_address = Some(address);
55 Poll::Ready(Ok(()))
56 }
57 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
58 Poll::Pending => Poll::Pending,
59 }
60 } else {
61 let read_result = self.buffer.read(buf.initialize_unfilled());
62 let result = match read_result {
63 Ok(n) => {
64 buf.advance(n);
65 Ok(())
66 }
67 Err(err) => Err(err),
68 };
69 Poll::Ready(result)
70 }
71 }
72}
73
74struct UdpWrite {
75 socket: Arc<UdpSocket>,
76 dest: Option<std::net::SocketAddr>,
77 sequence: u8,
78}
79
80pub struct AsyncUdpConnection {
81 reader: Mutex<AsyncPeekReader<UdpRead>>,
82 writer: Mutex<UdpWrite>,
83 protocol_version: MavlinkVersion,
84 recv_any_version: bool,
85 server: bool,
86 #[cfg(feature = "signing")]
87 signing_data: Option<SigningData>,
88}
89
90impl AsyncUdpConnection {
91 fn new(
92 socket: UdpSocket,
93 server: bool,
94 dest: Option<std::net::SocketAddr>,
95 ) -> io::Result<Self> {
96 let socket = Arc::new(socket);
97 Ok(Self {
98 server,
99 reader: Mutex::new(AsyncPeekReader::new(UdpRead {
100 socket: socket.clone(),
101 buffer: VecDeque::new(),
102 last_recv_address: None,
103 })),
104 writer: Mutex::new(UdpWrite {
105 socket,
106 dest,
107 sequence: 0,
108 }),
109 protocol_version: MavlinkVersion::V2,
110 recv_any_version: false,
111 #[cfg(feature = "signing")]
112 signing_data: None,
113 })
114 }
115}
116
117#[async_trait::async_trait]
118impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
119 async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
120 let mut reader = self.reader.lock().await;
121 let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
122 loop {
123 #[cfg(not(feature = "signing"))]
124 let result = read_versioned_msg_async(reader.deref_mut(), version).await;
125 #[cfg(feature = "signing")]
126 let result = read_versioned_msg_async_signed(
127 reader.deref_mut(),
128 version,
129 self.signing_data.as_ref(),
130 )
131 .await;
132 if self.server {
133 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
134 self.writer.lock().await.dest = addr;
135 }
136 }
137 if let ok @ Ok(..) = result {
138 return ok;
139 }
140 }
141 }
142
143 async fn send(
144 &self,
145 header: &MavHeader,
146 data: &M,
147 ) -> Result<usize, crate::error::MessageWriteError> {
148 let mut guard = self.writer.lock().await;
149 let state = &mut *guard;
150
151 let header = MavHeader {
152 sequence: state.sequence,
153 system_id: header.system_id,
154 component_id: header.component_id,
155 };
156
157 state.sequence = state.sequence.wrapping_add(1);
158
159 let len = if let Some(addr) = state.dest {
160 let mut buf = Vec::new();
161 #[cfg(not(feature = "signing"))]
162 write_versioned_msg_async(
163 &mut buf,
164 self.protocol_version,
165 header,
166 data,
167 #[cfg(feature = "signing")]
168 self.signing_data.as_ref(),
169 )
170 .await?;
171 #[cfg(feature = "signing")]
172 write_versioned_msg_signed(
173 &mut buf,
174 self.protocol_version,
175 header,
176 data,
177 #[cfg(feature = "signing")]
178 self.signing_data.as_ref(),
179 )?;
180 state.socket.send_to(&buf, addr).await?
181 } else {
182 0
183 };
184
185 Ok(len)
186 }
187
188 fn set_protocol_version(&mut self, version: MavlinkVersion) {
189 self.protocol_version = version;
190 }
191
192 fn protocol_version(&self) -> MavlinkVersion {
193 self.protocol_version
194 }
195
196 fn set_allow_recv_any_version(&mut self, allow: bool) {
197 self.recv_any_version = allow
198 }
199
200 fn allow_recv_any_version(&self) -> bool {
201 self.recv_any_version
202 }
203
204 #[cfg(feature = "signing")]
205 fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
206 self.signing_data = signing_data.map(SigningData::from_config)
207 }
208}
209
210#[async_trait]
211impl AsyncConnectable for UdpConnectable {
212 async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
213 where
214 M: Message + Sync + Send,
215 {
216 let (addr, server, dest): (&str, _, _) = match self.mode {
217 UdpMode::Udpin => (&self.address, true, None),
218 _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
219 };
220 let socket = UdpSocket::bind(addr).await?;
221 if matches!(self.mode, UdpMode::Udpcast) {
222 socket.set_broadcast(true)?;
223 }
224 Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// A datagram larger than the caller's buffer is delivered across several
    /// reads, and the next datagram is only consumed once the previous one has
    /// been drained completely.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Ephemeral ports keep the test independent of other listeners and of
        // parallel test runs (a hard-coded port made it flaky).
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let sender_addr = sender_socket.local_addr().unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect();

        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];
        for _ in 0..2 {
            // First read fills the caller's buffer with the datagram's head ...
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            // ... and the second returns the buffered tail of the same datagram.
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }

        assert_eq!(udp_reader.last_recv_address, Some(sender_addr));
    }
}