mavlink_core/async_connection/
udp.rs1use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10 io::{AsyncRead, ReadBuf},
11 net::UdpSocket,
12};
13
14use crate::connection::udp::config::{UdpConfig, UdpMode};
15use crate::MAVLinkMessageRaw;
16use crate::{async_peek_reader::AsyncPeekReader, MavHeader, MavlinkVersion, Message, ReadVersion};
17
18use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
19
20#[cfg(not(feature = "signing"))]
21use crate::{read_raw_versioned_msg_async, read_versioned_msg_async, write_versioned_msg_async};
22#[cfg(feature = "signing")]
23use crate::{
24 read_raw_versioned_msg_async_signed, read_versioned_msg_async_signed,
25 write_versioned_msg_signed, SigningConfig, SigningData,
26};
27
28struct UdpRead {
29 socket: Arc<UdpSocket>,
30 buffer: VecDeque<u8>,
31 last_recv_address: Option<std::net::SocketAddr>,
32}
33
34const MTU_SIZE: usize = 1500;
35impl AsyncRead for UdpRead {
36 fn poll_read(
37 mut self: core::pin::Pin<&mut Self>,
38 cx: &mut core::task::Context<'_>,
39 buf: &mut ReadBuf<'_>,
40 ) -> Poll<io::Result<()>> {
41 if self.buffer.is_empty() {
42 let mut read_buffer = [0u8; MTU_SIZE];
43 let mut read_buffer = ReadBuf::new(&mut read_buffer);
44
45 match self.socket.poll_recv_from(cx, &mut read_buffer) {
46 Poll::Ready(Ok(address)) => {
47 let n_buffer = read_buffer.filled().len();
48
49 let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
50 buf.advance(n);
51
52 self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
53 self.last_recv_address = Some(address);
54 Poll::Ready(Ok(()))
55 }
56 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
57 Poll::Pending => Poll::Pending,
58 }
59 } else {
60 let read_result = self.buffer.read(buf.initialize_unfilled());
61 let result = match read_result {
62 Ok(n) => {
63 buf.advance(n);
64 Ok(())
65 }
66 Err(err) => Err(err),
67 };
68 Poll::Ready(result)
69 }
70 }
71}
72
73struct UdpWrite {
74 socket: Arc<UdpSocket>,
75 dest: Option<std::net::SocketAddr>,
76 sequence: u8,
77}
78
79pub struct AsyncUdpConnection {
80 reader: Mutex<AsyncPeekReader<UdpRead>>,
81 writer: Mutex<UdpWrite>,
82 protocol_version: MavlinkVersion,
83 recv_any_version: bool,
84 server: bool,
85 #[cfg(feature = "signing")]
86 signing_data: Option<SigningData>,
87}
88
89impl AsyncUdpConnection {
90 fn new(
91 socket: UdpSocket,
92 server: bool,
93 dest: Option<std::net::SocketAddr>,
94 ) -> io::Result<Self> {
95 let socket = Arc::new(socket);
96 Ok(Self {
97 server,
98 reader: Mutex::new(AsyncPeekReader::new(UdpRead {
99 socket: socket.clone(),
100 buffer: VecDeque::new(),
101 last_recv_address: None,
102 })),
103 writer: Mutex::new(UdpWrite {
104 socket,
105 dest,
106 sequence: 0,
107 }),
108 protocol_version: MavlinkVersion::V2,
109 recv_any_version: false,
110 #[cfg(feature = "signing")]
111 signing_data: None,
112 })
113 }
114}
115
116#[async_trait::async_trait]
117impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
118 async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
119 let mut reader = self.reader.lock().await;
120 let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
121 loop {
122 #[cfg(not(feature = "signing"))]
123 let result = read_versioned_msg_async(reader.deref_mut(), version).await;
124 #[cfg(feature = "signing")]
125 let result = read_versioned_msg_async_signed(
126 reader.deref_mut(),
127 version,
128 self.signing_data.as_ref(),
129 )
130 .await;
131 if self.server {
132 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
133 self.writer.lock().await.dest = addr;
134 }
135 }
136 if let ok @ Ok(..) = result {
137 return ok;
138 }
139 }
140 }
141
142 async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
143 let mut reader = self.reader.lock().await;
144 let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
145 loop {
146 #[cfg(not(feature = "signing"))]
147 let result = read_raw_versioned_msg_async::<M, _>(reader.deref_mut(), version).await;
148 #[cfg(feature = "signing")]
149 let result = read_raw_versioned_msg_async_signed::<M, _>(
150 reader.deref_mut(),
151 version,
152 self.signing_data.as_ref(),
153 )
154 .await;
155 if self.server {
156 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
157 self.writer.lock().await.dest = addr;
158 }
159 }
160 if let ok @ Ok(..) = result {
161 return ok;
162 }
163 }
164 }
165
166 async fn send(
167 &self,
168 header: &MavHeader,
169 data: &M,
170 ) -> Result<usize, crate::error::MessageWriteError> {
171 let mut guard = self.writer.lock().await;
172 let state = &mut *guard;
173
174 let header = MavHeader {
175 sequence: state.sequence,
176 system_id: header.system_id,
177 component_id: header.component_id,
178 };
179
180 state.sequence = state.sequence.wrapping_add(1);
181
182 let len = if let Some(addr) = state.dest {
183 let mut buf = Vec::new();
184 #[cfg(not(feature = "signing"))]
185 write_versioned_msg_async(
186 &mut buf,
187 self.protocol_version,
188 header,
189 data,
190 #[cfg(feature = "signing")]
191 self.signing_data.as_ref(),
192 )
193 .await?;
194 #[cfg(feature = "signing")]
195 write_versioned_msg_signed(
196 &mut buf,
197 self.protocol_version,
198 header,
199 data,
200 #[cfg(feature = "signing")]
201 self.signing_data.as_ref(),
202 )?;
203 state.socket.send_to(&buf, addr).await?
204 } else {
205 0
206 };
207
208 Ok(len)
209 }
210
211 fn set_protocol_version(&mut self, version: MavlinkVersion) {
212 self.protocol_version = version;
213 }
214
215 fn protocol_version(&self) -> MavlinkVersion {
216 self.protocol_version
217 }
218
219 fn set_allow_recv_any_version(&mut self, allow: bool) {
220 self.recv_any_version = allow;
221 }
222
223 fn allow_recv_any_version(&self) -> bool {
224 self.recv_any_version
225 }
226
227 #[cfg(feature = "signing")]
228 fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
229 self.signing_data = signing_data.map(SigningData::from_config);
230 }
231}
232
233#[async_trait]
234impl AsyncConnectable for UdpConfig {
235 async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
236 where
237 M: Message + Sync + Send,
238 {
239 let (addr, server, dest): (&str, _, _) = match self.mode {
240 UdpMode::Udpin => (&self.address, true, None),
241 _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
242 };
243 let socket = UdpSocket::bind(addr).await?;
244 if matches!(self.mode, UdpMode::Udpcast) {
245 socket.set_broadcast(true)?;
246 }
247 Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// Reading a datagram through a buffer smaller than the datagram must
    /// return the datagram in consecutive pieces, and the next datagram must
    /// only be started once the previous one has been consumed completely.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Bind to an OS-assigned port so the test cannot collide with other
        // processes or concurrently running tests.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect::<Vec<_>>();

        // Queue two identical 50-byte datagrams.
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];

        // Each datagram is expected as a 30-byte read followed by the
        // remaining 20 bytes.
        for _ in 0..2 {
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[0..n_read], (0..30).collect::<Vec<_>>().as_slice());

            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[0..n_read], (30..50).collect::<Vec<_>>().as_slice());
        }
    }
}