blob: 78a6a124fbd1c1942c8de9577d1b696b9823137e [file] [log] [blame]
Add support for Unix Sockets
* add support for Unix Sockets to Stream/RefinedTcpStream.
* add a 'pub use' for ClientConnection and Stream
* convert ClientConnection::new to take Into<Stream> instead of two
RefinedTcpStreams.
* add a default SocketAddr value for if the stream returns an error.
Pull Request: https://github.com/tiny-http/tiny-http/pull/187
--- a/src/client.rs
+++ b/src/client.rs
@@ -10,6 +10,7 @@ use std::str::FromStr;
use common::{HTTPVersion, Method};
use util::RefinedTcpStream;
use util::{SequentialReader, SequentialReaderBuilder, SequentialWriterBuilder};
+use util::Stream;
use Request;
@@ -17,7 +18,7 @@ use Request;
/// and return Request objects.
pub struct ClientConnection {
// address of the client
- remote_addr: IoResult<SocketAddr>,
+ remote_addr: SocketAddr,
// sequence of Readers to the stream, so that the data is not read in
// the wrong order
@@ -50,11 +51,11 @@ enum ReadError {
impl ClientConnection {
/// Creates a new ClientConnection that takes ownership of the TcpStream.
- pub fn new(
- write_socket: RefinedTcpStream,
- mut read_socket: RefinedTcpStream,
- ) -> ClientConnection {
- let remote_addr = read_socket.peer_addr();
+ pub fn new<S>(stream: S) -> ClientConnection
+ where S: Into<Stream>
+ {
+ let (mut read_socket, write_socket) = RefinedTcpStream::new(stream);
+ let remote_addr = read_socket.peer_addr().unwrap_or(SocketAddr::from(([0,0,0,0], 0)));
let secure = read_socket.secure();
let mut source = SequentialReaderBuilder::new(BufReader::with_capacity(1024, read_socket));
@@ -152,7 +153,7 @@ impl ClientConnection {
path,
version.clone(),
headers,
- *self.remote_addr.as_ref().unwrap(),
+ self.remote_addr,
data_source,
writer,
)
diff --git a/src/lib.rs b/src/lib.rs
index 40b5491..60e8dcc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -117,8 +117,9 @@ use std::sync::Arc;
use std::thread;
use std::time::Duration;
-use client::ClientConnection;
+pub use client::ClientConnection;
use util::MessagesQueue;
+pub use util::Stream;
pub use common::{HTTPVersion, Header, HeaderField, Method, StatusCode};
pub use request::{ReadWrite, Request};
@@ -297,27 +298,24 @@ impl Server {
while !inside_close_trigger.load(Relaxed) {
let new_client = match server.accept() {
Ok((sock, _)) => {
- use util::RefinedTcpStream;
- let (read_closable, write_closable) = match ssl {
- None => RefinedTcpStream::new(sock),
+ let stream = match ssl {
+ None => sock,
#[cfg(feature = "ssl")]
Some(ref ssl) => {
let ssl = openssl::ssl::Ssl::new(ssl).expect("Couldn't create ssl");
// trying to apply SSL over the connection
// if an error occurs, we just close the socket and resume listening
- let sock = match ssl.accept(sock) {
+ match ssl.accept(sock) {
Ok(s) => s,
- Err(_) => continue,
- };
-
- RefinedTcpStream::new(sock)
- }
+ Err(_) => continue
+ }
+ },
#[cfg(not(feature = "ssl"))]
Some(_) => unreachable!(),
};
- Ok(ClientConnection::new(write_closable, read_closable))
- }
+ Ok(ClientConnection::new(stream))
+ },
Err(e) => Err(e),
};
diff --git a/src/util/mod.rs b/src/util/mod.rs
index 8abfb64..d775ee6 100644
--- a/src/util/mod.rs
+++ b/src/util/mod.rs
@@ -1,7 +1,7 @@
pub use self::custom_stream::CustomStream;
pub use self::equal_reader::EqualReader;
pub use self::messages_queue::MessagesQueue;
-pub use self::refined_tcp_stream::RefinedTcpStream;
+pub use self::refined_tcp_stream::{RefinedTcpStream, Stream};
pub use self::sequential::{SequentialReader, SequentialReaderBuilder};
pub use self::sequential::{SequentialWriter, SequentialWriterBuilder};
pub use self::task_pool::TaskPool;
diff --git a/src/util/refined_tcp_stream.rs b/src/util/refined_tcp_stream.rs
index 0c031a9..942a017 100644
--- a/src/util/refined_tcp_stream.rs
+++ b/src/util/refined_tcp_stream.rs
@@ -1,6 +1,7 @@
use std::io::Result as IoResult;
use std::io::{Read, Write};
use std::net::{Shutdown, SocketAddr, TcpStream};
+use std::os::unix::net::UnixStream;
#[cfg(feature = "ssl")]
use openssl::ssl::SslStream;
@@ -17,6 +18,7 @@ pub enum Stream {
Http(TcpStream),
#[cfg(feature = "ssl")]
Https(Arc<Mutex<SslStream<TcpStream>>>),
+ Unix(UnixStream),
}
impl From<TcpStream> for Stream {
@@ -34,6 +36,13 @@ impl From<SslStream<TcpStream>> for Stream {
}
}
+impl From<UnixStream> for Stream {
+ #[inline]
+ fn from(stream: UnixStream) -> Stream {
+ Stream::Unix(stream)
+ }
+}
+
impl RefinedTcpStream {
pub fn new<S>(stream: S) -> (RefinedTcpStream, RefinedTcpStream)
where
@@ -45,6 +54,7 @@ impl RefinedTcpStream {
Stream::Http(ref stream) => Stream::Http(stream.try_clone().unwrap()),
#[cfg(feature = "ssl")]
Stream::Https(ref stream) => Stream::Https(stream.clone()),
+ Stream::Unix(ref stream) => Stream::Unix(stream.try_clone().unwrap()),
};
let read = RefinedTcpStream {
@@ -69,6 +79,7 @@ impl RefinedTcpStream {
Stream::Http(_) => false,
#[cfg(feature = "ssl")]
Stream::Https(_) => true,
+ Stream::Unix(_) => false,
}
}
@@ -77,6 +88,7 @@ impl RefinedTcpStream {
Stream::Http(ref mut stream) => stream.peer_addr(),
#[cfg(feature = "ssl")]
Stream::Https(ref mut stream) => stream.lock().unwrap().get_ref().peer_addr(),
+ Stream::Unix(_) => Err(std::io::Error::new(std::io::ErrorKind::Other, "Peer addresses are not supported for Unix sockets")),
}
}
}
@@ -94,6 +106,7 @@ impl Drop for RefinedTcpStream {
.get_mut()
.shutdown(Shutdown::Read)
.ok(),
+ Stream::Unix(ref mut stream) => stream.shutdown(Shutdown::Read).ok(),
};
}
@@ -108,6 +121,7 @@ impl Drop for RefinedTcpStream {
.get_mut()
.shutdown(Shutdown::Write)
.ok(),
+ Stream::Unix(ref mut stream) => stream.shutdown(Shutdown::Write).ok(),
};
}
}
@@ -119,6 +133,7 @@ impl Read for RefinedTcpStream {
Stream::Http(ref mut stream) => stream.read(buf),
#[cfg(feature = "ssl")]
Stream::Https(ref mut stream) => stream.lock().unwrap().read(buf),
+ Stream::Unix(ref mut stream) => stream.read(buf),
}
}
}
@@ -129,6 +144,7 @@ impl Write for RefinedTcpStream {
Stream::Http(ref mut stream) => stream.write(buf),
#[cfg(feature = "ssl")]
Stream::Https(ref mut stream) => stream.lock().unwrap().write(buf),
+ Stream::Unix(ref mut stream) => stream.write(buf),
}
}
@@ -137,6 +153,7 @@ impl Write for RefinedTcpStream {
Stream::Http(ref mut stream) => stream.flush(),
#[cfg(feature = "ssl")]
Stream::Https(ref mut stream) => stream.lock().unwrap().flush(),
+ Stream::Unix(ref mut stream) => stream.flush(),
}
}
}