1use std::{
2 future::Future,
3 io,
4 pin::Pin,
5 task::{Context, Poll, ready},
6};
7
8use bytes::Bytes;
9use proto::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
10use thiserror::Error;
11
12use crate::{VarInt, connection::ConnectionRef};
13
14#[derive(Debug)]
31pub struct SendStream {
32 conn: ConnectionRef,
33 stream: StreamId,
34 is_0rtt: bool,
35}
36
37impl SendStream {
38 pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
39 Self {
40 conn,
41 stream,
42 is_0rtt,
43 }
44 }
45
46 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
53 Write { stream: self, buf }.await
54 }
55
56 pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
60 WriteAll { stream: self, buf }.await
61 }
62
63 pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
71 WriteChunks { stream: self, bufs }.await
72 }
73
74 pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
78 WriteChunk {
79 stream: self,
80 buf: [buf],
81 }
82 .await
83 }
84
85 pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
89 WriteAllChunks {
90 stream: self,
91 bufs,
92 offset: 0,
93 }
94 .await
95 }
96
97 fn execute_poll<F, R>(&mut self, cx: &mut Context, write_fn: F) -> Poll<Result<R, WriteError>>
98 where
99 F: FnOnce(&mut proto::SendStream) -> Result<R, proto::WriteError>,
100 {
101 use proto::WriteError::*;
102 let mut conn = self.conn.state.lock("SendStream::poll_write");
103 if self.is_0rtt {
104 conn.check_0rtt()
105 .map_err(|()| WriteError::ZeroRttRejected)?;
106 }
107 if let Some(ref x) = conn.error {
108 return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
109 }
110
111 let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
112 Ok(result) => result,
113 Err(Blocked) => {
114 conn.blocked_writers.insert(self.stream, cx.waker().clone());
115 return Poll::Pending;
116 }
117 Err(Stopped(error_code)) => {
118 return Poll::Ready(Err(WriteError::Stopped(error_code)));
119 }
120 Err(ClosedStream) => {
121 return Poll::Ready(Err(WriteError::ClosedStream));
122 }
123 };
124
125 conn.wake();
126 Poll::Ready(Ok(result))
127 }
128
129 pub fn finish(&mut self) -> Result<(), ClosedStream> {
141 let mut conn = self.conn.state.lock("finish");
142 match conn.inner.send_stream(self.stream).finish() {
143 Ok(()) => {
144 conn.wake();
145 Ok(())
146 }
147 Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
148 Err(FinishError::Stopped(_)) => Ok(()),
151 }
152 }
153
154 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
164 let mut conn = self.conn.state.lock("SendStream::reset");
165 if self.is_0rtt && conn.check_0rtt().is_err() {
166 return Ok(());
167 }
168 conn.inner.send_stream(self.stream).reset(error_code)?;
169 conn.wake();
170 Ok(())
171 }
172
173 pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
181 let mut conn = self.conn.state.lock("SendStream::set_priority");
182 conn.inner.send_stream(self.stream).set_priority(priority)?;
183 Ok(())
184 }
185
186 pub fn priority(&self) -> Result<i32, ClosedStream> {
188 let mut conn = self.conn.state.lock("SendStream::priority");
189 conn.inner.send_stream(self.stream).priority()
190 }
191
192 pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
203 Stopped { stream: self }.await
204 }
205
206 fn poll_stopped(&mut self, cx: &mut Context) -> Poll<Result<Option<VarInt>, StoppedError>> {
207 let mut conn = self.conn.state.lock("SendStream::poll_stopped");
208
209 if self.is_0rtt {
210 conn.check_0rtt()
211 .map_err(|()| StoppedError::ZeroRttRejected)?;
212 }
213
214 match conn.inner.send_stream(self.stream).stopped() {
215 Err(_) => Poll::Ready(Ok(None)),
216 Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))),
217 Ok(None) => {
218 if let Some(e) = &conn.error {
219 return Poll::Ready(Err(e.clone().into()));
220 }
221 conn.stopped.insert(self.stream, cx.waker().clone());
222 Poll::Pending
223 }
224 }
225 }
226
227 pub fn id(&self) -> StreamId {
229 self.stream
230 }
231
232 pub fn poll_write(
240 self: Pin<&mut Self>,
241 cx: &mut Context,
242 buf: &[u8],
243 ) -> Poll<Result<usize, WriteError>> {
244 self.get_mut().execute_poll(cx, |stream| stream.write(buf))
245 }
246}
247
/// `futures-io` adapter: writes go through the same path as [`SendStream::write`], and
/// closing finishes the stream.
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.execute_poll(cx, |stream| stream.write(buf))
            .map_err(io::Error::from)
    }

    /// Always ready: writes hand data straight to the connection, so nothing is buffered here.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the stream immediately; never returns `Pending`.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        let finished = self.get_mut().finish();
        Poll::Ready(finished.map_err(Into::into))
    }
}
262
263impl tokio::io::AsyncWrite for SendStream {
264 fn poll_write(
265 self: Pin<&mut Self>,
266 cx: &mut Context<'_>,
267 buf: &[u8],
268 ) -> Poll<io::Result<usize>> {
269 Self::execute_poll(self.get_mut(), cx, |stream| stream.write(buf)).map_err(Into::into)
270 }
271
272 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
273 Poll::Ready(Ok(()))
274 }
275
276 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
277 Poll::Ready(self.get_mut().finish().map_err(Into::into))
278 }
279}
280
281impl Drop for SendStream {
282 fn drop(&mut self) {
283 let mut conn = self.conn.state.lock("SendStream::drop");
284
285 conn.stopped.remove(&self.stream);
287 conn.blocked_writers.remove(&self.stream);
288
289 if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
290 return;
291 }
292 match conn.inner.send_stream(self.stream).finish() {
293 Ok(()) => conn.wake(),
294 Err(FinishError::Stopped(reason)) => {
295 if conn.inner.send_stream(self.stream).reset(reason).is_ok() {
296 conn.wake();
297 }
298 }
299 Err(FinishError::ClosedStream) => {}
301 }
302 }
303}
304
305struct Stopped<'a> {
307 stream: &'a mut SendStream,
308}
309
310impl Future for Stopped<'_> {
311 type Output = Result<Option<VarInt>, StoppedError>;
312
313 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
314 self.get_mut().stream.poll_stopped(cx)
315 }
316}
317
318struct Write<'a> {
322 stream: &'a mut SendStream,
323 buf: &'a [u8],
324}
325
326impl Future for Write<'_> {
327 type Output = Result<usize, WriteError>;
328 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
329 let this = self.get_mut();
330 let buf = this.buf;
331 this.stream.execute_poll(cx, |s| s.write(buf))
332 }
333}
334
335struct WriteAll<'a> {
339 stream: &'a mut SendStream,
340 buf: &'a [u8],
341}
342
343impl Future for WriteAll<'_> {
344 type Output = Result<(), WriteError>;
345 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
346 let this = self.get_mut();
347 loop {
348 if this.buf.is_empty() {
349 return Poll::Ready(Ok(()));
350 }
351 let buf = this.buf;
352 let n = ready!(this.stream.execute_poll(cx, |s| s.write(buf)))?;
353 this.buf = &this.buf[n..];
354 }
355 }
356}
357
358struct WriteChunks<'a> {
362 stream: &'a mut SendStream,
363 bufs: &'a mut [Bytes],
364}
365
366impl Future for WriteChunks<'_> {
367 type Output = Result<Written, WriteError>;
368 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
369 let this = self.get_mut();
370 let bufs = &mut *this.bufs;
371 this.stream.execute_poll(cx, |s| s.write_chunks(bufs))
372 }
373}
374
375struct WriteChunk<'a> {
379 stream: &'a mut SendStream,
380 buf: [Bytes; 1],
381}
382
383impl Future for WriteChunk<'_> {
384 type Output = Result<(), WriteError>;
385 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
386 let this = self.get_mut();
387 loop {
388 if this.buf[0].is_empty() {
389 return Poll::Ready(Ok(()));
390 }
391 let bufs = &mut this.buf[..];
392 ready!(this.stream.execute_poll(cx, |s| s.write_chunks(bufs)))?;
393 }
394 }
395}
396
397struct WriteAllChunks<'a> {
401 stream: &'a mut SendStream,
402 bufs: &'a mut [Bytes],
403 offset: usize,
404}
405
406impl Future for WriteAllChunks<'_> {
407 type Output = Result<(), WriteError>;
408 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
409 let this = self.get_mut();
410 loop {
411 if this.offset == this.bufs.len() {
412 return Poll::Ready(Ok(()));
413 }
414 let bufs = &mut this.bufs[this.offset..];
415 let written = ready!(this.stream.execute_poll(cx, |s| s.write_chunks(bufs)))?;
416 this.offset += written.chunks;
417 }
418 }
419}
420
421#[derive(Debug, Error, Clone, PartialEq, Eq)]
423pub enum WriteError {
424 #[error("sending stopped by peer: error {0}")]
428 Stopped(VarInt),
429 #[error("connection lost")]
431 ConnectionLost(#[from] ConnectionError),
432 #[error("closed stream")]
434 ClosedStream,
435 #[error("0-RTT rejected")]
442 ZeroRttRejected,
443}
444
445impl From<ClosedStream> for WriteError {
446 #[inline]
447 fn from(_: ClosedStream) -> Self {
448 Self::ClosedStream
449 }
450}
451
452impl From<StoppedError> for WriteError {
453 fn from(x: StoppedError) -> Self {
454 match x {
455 StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
456 StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
457 }
458 }
459}
460
461impl From<WriteError> for io::Error {
462 fn from(x: WriteError) -> Self {
463 use WriteError::*;
464 let kind = match x {
465 Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
466 ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
467 };
468 Self::new(kind, x)
469 }
470}
471
472#[derive(Debug, Error, Clone, PartialEq, Eq)]
474pub enum StoppedError {
475 #[error("connection lost")]
477 ConnectionLost(#[from] ConnectionError),
478 #[error("0-RTT rejected")]
485 ZeroRttRejected,
486}
487
488impl From<StoppedError> for io::Error {
489 fn from(x: StoppedError) -> Self {
490 use StoppedError::*;
491 let kind = match x {
492 ZeroRttRejected => io::ErrorKind::ConnectionReset,
493 ConnectionLost(_) => io::ErrorKind::NotConnected,
494 };
495 Self::new(kind, x)
496 }
497}