mavlink_core/async_connection/
udp.rs1use core::{ops::DerefMut, task::Poll};
4use std::io;
5use std::{collections::VecDeque, io::Read, sync::Arc};
6
7use async_trait::async_trait;
8use futures::lock::Mutex;
9use tokio::{
10 io::{AsyncRead, ReadBuf},
11 net::UdpSocket,
12};
13
14use crate::connection::udp::config::{UdpConfig, UdpMode};
15use crate::MAVLinkMessageRaw;
16use crate::{async_peek_reader::AsyncPeekReader, MavHeader, MavlinkVersion, Message, ReadVersion};
17
18use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
19
20#[cfg(not(feature = "signing"))]
21use crate::{
22 read_versioned_msg_async, read_versioned_raw_message_async, write_versioned_msg_async,
23};
24#[cfg(feature = "signing")]
25use crate::{
26 read_versioned_msg_async_signed, read_versioned_raw_message_async_signed,
27 write_versioned_msg_signed, SigningConfig, SigningData,
28};
29
/// Receive half of a UDP connection, exposed as a byte stream through
/// `AsyncRead`.
struct UdpRead {
    /// Socket the datagrams are received from.
    socket: Arc<UdpSocket>,
    /// Bytes of the most recently received datagram that did not fit into the
    /// caller's buffer. They are handed out before the socket is polled again.
    buffer: VecDeque<u8>,
    /// Sender address of the most recently received datagram, or `None` if
    /// nothing has been received yet.
    last_recv_address: Option<std::net::SocketAddr>,
}
35
36const MTU_SIZE: usize = 1500;
37impl AsyncRead for UdpRead {
38 fn poll_read(
39 mut self: core::pin::Pin<&mut Self>,
40 cx: &mut core::task::Context<'_>,
41 buf: &mut ReadBuf<'_>,
42 ) -> Poll<io::Result<()>> {
43 if self.buffer.is_empty() {
44 let mut read_buffer = [0u8; MTU_SIZE];
45 let mut read_buffer = ReadBuf::new(&mut read_buffer);
46
47 match self.socket.poll_recv_from(cx, &mut read_buffer) {
48 Poll::Ready(Ok(address)) => {
49 let n_buffer = read_buffer.filled().len();
50
51 let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
52 buf.advance(n);
53
54 self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
55 self.last_recv_address = Some(address);
56 Poll::Ready(Ok(()))
57 }
58 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
59 Poll::Pending => Poll::Pending,
60 }
61 } else {
62 let read_result = self.buffer.read(buf.initialize_unfilled());
63 let result = match read_result {
64 Ok(n) => {
65 buf.advance(n);
66 Ok(())
67 }
68 Err(err) => Err(err),
69 };
70 Poll::Ready(result)
71 }
72 }
73}
74
/// Send half of a UDP connection.
struct UdpWrite {
    /// Socket the datagrams are sent through.
    socket: Arc<UdpSocket>,
    /// Address outgoing messages are sent to. While this is `None`, `send`
    /// transmits nothing and reports zero bytes written.
    dest: Option<std::net::SocketAddr>,
    /// Sequence number placed in the header of the next outgoing message;
    /// incremented with wrap-around on every `send` call.
    sequence: u8,
}
80
/// Asynchronous MAVLink connection over a UDP socket.
pub struct AsyncUdpConnection {
    /// Receive side; locked for the whole duration of `recv` / `recv_raw`.
    reader: Mutex<AsyncPeekReader<UdpRead>>,
    /// Send side, including the current destination and sequence counter.
    writer: Mutex<UdpWrite>,
    /// MAVLink protocol version used when serialising outgoing messages.
    protocol_version: MavlinkVersion,
    /// Flag exposed through `allow_recv_any_version`.
    // NOTE(review): presumably consulted by `ReadVersion::from_async_conn_cfg` — confirm.
    recv_any_version: bool,
    /// When `true`, every receive attempt redirects outgoing traffic to the
    /// sender of the most recently received datagram.
    server: bool,
    /// Signing state handed to the signed read/write helpers, if configured.
    #[cfg(feature = "signing")]
    signing_data: Option<SigningData>,
}
90
91impl AsyncUdpConnection {
92 fn new(
93 socket: UdpSocket,
94 server: bool,
95 dest: Option<std::net::SocketAddr>,
96 ) -> io::Result<Self> {
97 let socket = Arc::new(socket);
98 Ok(Self {
99 server,
100 reader: Mutex::new(AsyncPeekReader::new(UdpRead {
101 socket: socket.clone(),
102 buffer: VecDeque::new(),
103 last_recv_address: None,
104 })),
105 writer: Mutex::new(UdpWrite {
106 socket,
107 dest,
108 sequence: 0,
109 }),
110 protocol_version: MavlinkVersion::V2,
111 recv_any_version: false,
112 #[cfg(feature = "signing")]
113 signing_data: None,
114 })
115 }
116}
117
118#[async_trait::async_trait]
119impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
120 async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
121 let mut reader = self.reader.lock().await;
122 let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
123 loop {
124 #[cfg(not(feature = "signing"))]
125 let result = read_versioned_msg_async(reader.deref_mut(), version).await;
126 #[cfg(feature = "signing")]
127 let result = read_versioned_msg_async_signed(
128 reader.deref_mut(),
129 version,
130 self.signing_data.as_ref(),
131 )
132 .await;
133 if self.server {
134 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
135 self.writer.lock().await.dest = addr;
136 }
137 }
138 if let ok @ Ok(..) = result {
139 return ok;
140 }
141 }
142 }
143
144 async fn recv_raw(&self) -> Result<MAVLinkMessageRaw, crate::error::MessageReadError> {
145 let mut reader = self.reader.lock().await;
146 let version = ReadVersion::from_async_conn_cfg::<_, M>(self);
147 loop {
148 #[cfg(not(feature = "signing"))]
149 let result =
150 read_versioned_raw_message_async::<M, _>(reader.deref_mut(), version).await;
151 #[cfg(feature = "signing")]
152 let result = read_versioned_raw_message_async_signed::<M, _>(
153 reader.deref_mut(),
154 version,
155 self.signing_data.as_ref(),
156 )
157 .await;
158 if self.server {
159 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
160 self.writer.lock().await.dest = addr;
161 }
162 }
163 if let ok @ Ok(..) = result {
164 return ok;
165 }
166 }
167 }
168
169 async fn send(
170 &self,
171 header: &MavHeader,
172 data: &M,
173 ) -> Result<usize, crate::error::MessageWriteError> {
174 let mut guard = self.writer.lock().await;
175 let state = &mut *guard;
176
177 let header = MavHeader {
178 sequence: state.sequence,
179 system_id: header.system_id,
180 component_id: header.component_id,
181 };
182
183 state.sequence = state.sequence.wrapping_add(1);
184
185 let len = if let Some(addr) = state.dest {
186 let mut buf = Vec::new();
187 #[cfg(not(feature = "signing"))]
188 write_versioned_msg_async(
189 &mut buf,
190 self.protocol_version,
191 header,
192 data,
193 #[cfg(feature = "signing")]
194 self.signing_data.as_ref(),
195 )
196 .await?;
197 #[cfg(feature = "signing")]
198 write_versioned_msg_signed(
199 &mut buf,
200 self.protocol_version,
201 header,
202 data,
203 #[cfg(feature = "signing")]
204 self.signing_data.as_ref(),
205 )?;
206 state.socket.send_to(&buf, addr).await?
207 } else {
208 0
209 };
210
211 Ok(len)
212 }
213
214 fn set_protocol_version(&mut self, version: MavlinkVersion) {
215 self.protocol_version = version;
216 }
217
218 fn protocol_version(&self) -> MavlinkVersion {
219 self.protocol_version
220 }
221
222 fn set_allow_recv_any_version(&mut self, allow: bool) {
223 self.recv_any_version = allow;
224 }
225
226 fn allow_recv_any_version(&self) -> bool {
227 self.recv_any_version
228 }
229
230 #[cfg(feature = "signing")]
231 fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
232 self.signing_data = signing_data.map(SigningData::from_config);
233 }
234}
235
236#[async_trait]
237impl AsyncConnectable for UdpConfig {
238 async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
239 where
240 M: Message + Sync + Send,
241 {
242 let (addr, server, dest): (&str, _, _) = match self.mode {
243 UdpMode::Udpin => (&self.address, true, None),
244 _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
245 };
246 let socket = UdpSocket::bind(addr).await?;
247 if matches!(self.mode, UdpMode::Udpcast) {
248 socket.set_broadcast(true)?;
249 }
250 Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// Checks that a datagram larger than the caller's buffer is delivered
    /// across several reads, and that a following datagram starts a new read.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Bind to port 0 so the OS picks a free port; a hard-coded port makes
        // the test fail whenever that port is already in use.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();

        let datagram: Vec<u8> = (0..50).collect::<Vec<_>>();

        // Queue two identical 50-byte datagrams before reading anything.
        let mut n_sent = sender_socket.send(&datagram).await.unwrap();
        assert_eq!(n_sent, datagram.len());
        n_sent = sender_socket.send(&datagram).await.unwrap();
        assert_eq!(n_sent, datagram.len());

        let mut buf = [0u8; 30];

        // First datagram: 30 bytes fit, the remaining 20 come from the internal buffer.
        let mut n_read = udp_reader.read(&mut buf).await.unwrap();
        assert_eq!(n_read, 30);
        assert_eq!(&buf[0..n_read], (0..30).collect::<Vec<_>>().as_slice());

        n_read = udp_reader.read(&mut buf).await.unwrap();
        assert_eq!(n_read, 20);
        assert_eq!(&buf[0..n_read], (30..50).collect::<Vec<_>>().as_slice());

        // Second datagram: same split again.
        n_read = udp_reader.read(&mut buf).await.unwrap();
        assert_eq!(n_read, 30);
        assert_eq!(&buf[0..n_read], (0..30).collect::<Vec<_>>().as_slice());

        n_read = udp_reader.read(&mut buf).await.unwrap();
        assert_eq!(n_read, 20);
        assert_eq!(&buf[0..n_read], (30..50).collect::<Vec<_>>().as_slice());
    }
}