mavlink_core/async_connection/
udp.rs1use core::{ops::DerefMut, task::Poll};
4use std::{collections::VecDeque, io::Read, sync::Arc};
5
6use async_trait::async_trait;
7use tokio::{
8 io::{self, AsyncRead, ReadBuf},
9 net::UdpSocket,
10 sync::Mutex,
11};
12
13use crate::{
14 async_peek_reader::AsyncPeekReader,
15 connectable::{UdpConnectable, UdpMode},
16 MavHeader, MavlinkVersion, Message,
17};
18
19use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
20
21#[cfg(not(feature = "signing"))]
22use crate::{read_versioned_msg_async, write_versioned_msg_async};
23#[cfg(feature = "signing")]
24use crate::{
25 read_versioned_msg_async_signed, write_versioned_msg_signed, SigningConfig, SigningData,
26};
27
28struct UdpRead {
29 socket: Arc<UdpSocket>,
30 buffer: VecDeque<u8>,
31 last_recv_address: Option<std::net::SocketAddr>,
32}
33
34const MTU_SIZE: usize = 1500;
35impl AsyncRead for UdpRead {
36 fn poll_read(
37 mut self: core::pin::Pin<&mut Self>,
38 cx: &mut core::task::Context<'_>,
39 buf: &mut io::ReadBuf<'_>,
40 ) -> Poll<io::Result<()>> {
41 if self.buffer.is_empty() {
42 let mut read_buffer = [0u8; MTU_SIZE];
43 let mut read_buffer = ReadBuf::new(&mut read_buffer);
44
45 match self.socket.poll_recv_from(cx, &mut read_buffer) {
46 Poll::Ready(Ok(address)) => {
47 let n_buffer = read_buffer.filled().len();
48
49 let n = (&read_buffer.filled()[0..n_buffer]).read(buf.initialize_unfilled())?;
50 buf.advance(n);
51
52 self.buffer.extend(&read_buffer.filled()[n..n_buffer]);
53 self.last_recv_address = Some(address);
54 Poll::Ready(Ok(()))
55 }
56 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
57 Poll::Pending => Poll::Pending,
58 }
59 } else {
60 let read_result = self.buffer.read(buf.initialize_unfilled());
61 let result = match read_result {
62 Ok(n) => {
63 buf.advance(n);
64 Ok(())
65 }
66 Err(err) => Err(err),
67 };
68 Poll::Ready(result)
69 }
70 }
71}
72
73struct UdpWrite {
74 socket: Arc<UdpSocket>,
75 dest: Option<std::net::SocketAddr>,
76 sequence: u8,
77}
78
79pub struct AsyncUdpConnection {
80 reader: Mutex<AsyncPeekReader<UdpRead>>,
81 writer: Mutex<UdpWrite>,
82 protocol_version: MavlinkVersion,
83 server: bool,
84 #[cfg(feature = "signing")]
85 signing_data: Option<SigningData>,
86}
87
88impl AsyncUdpConnection {
89 fn new(
90 socket: UdpSocket,
91 server: bool,
92 dest: Option<std::net::SocketAddr>,
93 ) -> io::Result<Self> {
94 let socket = Arc::new(socket);
95 Ok(Self {
96 server,
97 reader: Mutex::new(AsyncPeekReader::new(UdpRead {
98 socket: socket.clone(),
99 buffer: VecDeque::new(),
100 last_recv_address: None,
101 })),
102 writer: Mutex::new(UdpWrite {
103 socket,
104 dest,
105 sequence: 0,
106 }),
107 protocol_version: MavlinkVersion::V2,
108 #[cfg(feature = "signing")]
109 signing_data: None,
110 })
111 }
112}
113
114#[async_trait::async_trait]
115impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
116 async fn recv(&self) -> Result<(MavHeader, M), crate::error::MessageReadError> {
117 let mut reader = self.reader.lock().await;
118
119 loop {
120 #[cfg(not(feature = "signing"))]
121 let result = read_versioned_msg_async(reader.deref_mut(), self.protocol_version).await;
122 #[cfg(feature = "signing")]
123 let result = read_versioned_msg_async_signed(
124 reader.deref_mut(),
125 self.protocol_version,
126 self.signing_data.as_ref(),
127 )
128 .await;
129 if self.server {
130 if let addr @ Some(_) = reader.reader_ref().last_recv_address {
131 self.writer.lock().await.dest = addr;
132 }
133 }
134 if let ok @ Ok(..) = result {
135 return ok;
136 }
137 }
138 }
139
140 async fn send(
141 &self,
142 header: &MavHeader,
143 data: &M,
144 ) -> Result<usize, crate::error::MessageWriteError> {
145 let mut guard = self.writer.lock().await;
146 let state = &mut *guard;
147
148 let header = MavHeader {
149 sequence: state.sequence,
150 system_id: header.system_id,
151 component_id: header.component_id,
152 };
153
154 state.sequence = state.sequence.wrapping_add(1);
155
156 let len = if let Some(addr) = state.dest {
157 let mut buf = Vec::new();
158 #[cfg(not(feature = "signing"))]
159 write_versioned_msg_async(
160 &mut buf,
161 self.protocol_version,
162 header,
163 data,
164 #[cfg(feature = "signing")]
165 self.signing_data.as_ref(),
166 )
167 .await?;
168 #[cfg(feature = "signing")]
169 write_versioned_msg_signed(
170 &mut buf,
171 self.protocol_version,
172 header,
173 data,
174 #[cfg(feature = "signing")]
175 self.signing_data.as_ref(),
176 )?;
177 state.socket.send_to(&buf, addr).await?
178 } else {
179 0
180 };
181
182 Ok(len)
183 }
184
185 fn set_protocol_version(&mut self, version: MavlinkVersion) {
186 self.protocol_version = version;
187 }
188
189 fn get_protocol_version(&self) -> MavlinkVersion {
190 self.protocol_version
191 }
192
193 #[cfg(feature = "signing")]
194 fn setup_signing(&mut self, signing_data: Option<SigningConfig>) {
195 self.signing_data = signing_data.map(SigningData::from_config)
196 }
197}
198
199#[async_trait]
200impl AsyncConnectable for UdpConnectable {
201 async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
202 where
203 M: Message + Sync + Send,
204 {
205 let (addr, server, dest): (&str, _, _) = match self.mode {
206 UdpMode::Udpin => (&self.address, true, None),
207 _ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
208 };
209 let socket = UdpSocket::bind(addr).await?;
210 if matches!(self.mode, UdpMode::Udpcast) {
211 socket.set_broadcast(true)?;
212 }
213 Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use io::AsyncReadExt;

    /// A datagram larger than the caller's buffer is delivered across several
    /// reads, and consecutive datagrams are neither merged nor lost.
    #[tokio::test]
    async fn test_datagram_buffering() {
        // Port 0 lets the OS pick a free port, so the test cannot collide with
        // another process or a concurrently running test.
        let receiver_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let receiver_addr = receiver_socket.local_addr().unwrap();
        let mut udp_reader = UdpRead {
            socket: receiver_socket.clone(),
            buffer: VecDeque::new(),
            last_recv_address: None,
        };
        let sender_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sender_socket.connect(receiver_addr).await.unwrap();
        let sender_addr = sender_socket.local_addr().unwrap();

        let datagram: Vec<u8> = (0..50).collect();
        for _ in 0..2 {
            let n_sent = sender_socket.send(&datagram).await.unwrap();
            assert_eq!(n_sent, datagram.len());
        }

        let mut buf = [0u8; 30];
        for _ in 0..2 {
            // The first read fills `buf`. The remaining 20 bytes of the same
            // datagram come from the reader's internal buffer.
            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 30);
            assert_eq!(&buf[..n_read], &datagram[..30]);

            let n_read = udp_reader.read(&mut buf).await.unwrap();
            assert_eq!(n_read, 20);
            assert_eq!(&buf[..n_read], &datagram[30..]);
        }

        assert_eq!(udp_reader.last_recv_address, Some(sender_addr));
    }
}