| Add support for Unix Sockets |
| |
| * add support for Unix Sockets to Stream/RefinedTcpStream. |
| * add a 'pub use' for ClientConnection and Stream |
| * convert ClientConnection::new to take Into<Stream> instead of two |
| RefinedTcpStreams. |
| * add a default SocketAddr value for if the stream returns an error. |
| |
| Pull Request: https://github.com/tiny-http/tiny-http/pull/187 |
| |
| --- a/src/client.rs |
| +++ b/src/client.rs |
| @@ -10,6 +10,7 @@ use std::str::FromStr; |
| use common::{HTTPVersion, Method}; |
| use util::RefinedTcpStream; |
| use util::{SequentialReader, SequentialReaderBuilder, SequentialWriterBuilder}; |
| +use util::Stream; |
| |
| use Request; |
| |
| @@ -17,7 +18,7 @@ use Request; |
| /// and return Request objects. |
| pub struct ClientConnection { |
| // address of the client |
| - remote_addr: IoResult<SocketAddr>, |
| + remote_addr: SocketAddr, |
| |
| // sequence of Readers to the stream, so that the data is not read in |
| // the wrong order |
| @@ -50,11 +51,11 @@ enum ReadError { |
| |
| impl ClientConnection { |
| /// Creates a new ClientConnection that takes ownership of the TcpStream. |
| - pub fn new( |
| - write_socket: RefinedTcpStream, |
| - mut read_socket: RefinedTcpStream, |
| - ) -> ClientConnection { |
| - let remote_addr = read_socket.peer_addr(); |
| + pub fn new<S>(stream: S) -> ClientConnection |
| + where S: Into<Stream> |
| + { |
| + let (mut read_socket, write_socket) = RefinedTcpStream::new(stream); |
| + let remote_addr = read_socket.peer_addr().unwrap_or(SocketAddr::from(([0,0,0,0], 0))); |
| let secure = read_socket.secure(); |
| |
| let mut source = SequentialReaderBuilder::new(BufReader::with_capacity(1024, read_socket)); |
| @@ -152,7 +153,7 @@ impl ClientConnection { |
| path, |
| version.clone(), |
| headers, |
| - *self.remote_addr.as_ref().unwrap(), |
| + self.remote_addr, |
| data_source, |
| writer, |
| ) |
| diff --git a/src/lib.rs b/src/lib.rs |
| index 40b5491..60e8dcc 100644 |
| --- a/src/lib.rs |
| +++ b/src/lib.rs |
| @@ -117,8 +117,9 @@ use std::sync::Arc; |
| use std::thread; |
| use std::time::Duration; |
| |
| -use client::ClientConnection; |
| +pub use client::ClientConnection; |
| use util::MessagesQueue; |
| +pub use util::Stream; |
| |
| pub use common::{HTTPVersion, Header, HeaderField, Method, StatusCode}; |
| pub use request::{ReadWrite, Request}; |
| @@ -297,27 +298,24 @@ impl Server { |
| while !inside_close_trigger.load(Relaxed) { |
| let new_client = match server.accept() { |
| Ok((sock, _)) => { |
| - use util::RefinedTcpStream; |
| - let (read_closable, write_closable) = match ssl { |
| - None => RefinedTcpStream::new(sock), |
| + let stream = match ssl { |
| + None => sock, |
| #[cfg(feature = "ssl")] |
| Some(ref ssl) => { |
| let ssl = openssl::ssl::Ssl::new(ssl).expect("Couldn't create ssl"); |
| // trying to apply SSL over the connection |
| // if an error occurs, we just close the socket and resume listening |
| - let sock = match ssl.accept(sock) { |
| + match ssl.accept(sock) { |
| Ok(s) => s, |
| - Err(_) => continue, |
| - }; |
| - |
| - RefinedTcpStream::new(sock) |
| - } |
| + Err(_) => continue |
| + } |
| + }, |
| #[cfg(not(feature = "ssl"))] |
| Some(_) => unreachable!(), |
| }; |
| |
| - Ok(ClientConnection::new(write_closable, read_closable)) |
| - } |
| + Ok(ClientConnection::new(stream)) |
| + }, |
| Err(e) => Err(e), |
| }; |
| |
| diff --git a/src/util/mod.rs b/src/util/mod.rs |
| index 8abfb64..d775ee6 100644 |
| --- a/src/util/mod.rs |
| +++ b/src/util/mod.rs |
| @@ -1,7 +1,7 @@ |
| pub use self::custom_stream::CustomStream; |
| pub use self::equal_reader::EqualReader; |
| pub use self::messages_queue::MessagesQueue; |
| -pub use self::refined_tcp_stream::RefinedTcpStream; |
| +pub use self::refined_tcp_stream::{RefinedTcpStream, Stream}; |
| pub use self::sequential::{SequentialReader, SequentialReaderBuilder}; |
| pub use self::sequential::{SequentialWriter, SequentialWriterBuilder}; |
| pub use self::task_pool::TaskPool; |
| diff --git a/src/util/refined_tcp_stream.rs b/src/util/refined_tcp_stream.rs |
| index 0c031a9..942a017 100644 |
| --- a/src/util/refined_tcp_stream.rs |
| +++ b/src/util/refined_tcp_stream.rs |
| @@ -1,6 +1,7 @@ |
| use std::io::Result as IoResult; |
| use std::io::{Read, Write}; |
| use std::net::{Shutdown, SocketAddr, TcpStream}; |
| +use std::os::unix::net::UnixStream; |
| |
| #[cfg(feature = "ssl")] |
| use openssl::ssl::SslStream; |
| @@ -17,6 +18,7 @@ pub enum Stream { |
| Http(TcpStream), |
| #[cfg(feature = "ssl")] |
| Https(Arc<Mutex<SslStream<TcpStream>>>), |
| + Unix(UnixStream), |
| } |
| |
| impl From<TcpStream> for Stream { |
| @@ -34,6 +36,13 @@ impl From<SslStream<TcpStream>> for Stream { |
| } |
| } |
| |
| +impl From<UnixStream> for Stream { |
| + #[inline] |
| + fn from(stream: UnixStream) -> Stream { |
| + Stream::Unix(stream) |
| + } |
| +} |
| + |
| impl RefinedTcpStream { |
| pub fn new<S>(stream: S) -> (RefinedTcpStream, RefinedTcpStream) |
| where |
| @@ -45,6 +54,7 @@ impl RefinedTcpStream { |
| Stream::Http(ref stream) => Stream::Http(stream.try_clone().unwrap()), |
| #[cfg(feature = "ssl")] |
| Stream::Https(ref stream) => Stream::Https(stream.clone()), |
| + Stream::Unix(ref stream) => Stream::Unix(stream.try_clone().unwrap()), |
| }; |
| |
| let read = RefinedTcpStream { |
| @@ -69,6 +79,7 @@ impl RefinedTcpStream { |
| Stream::Http(_) => false, |
| #[cfg(feature = "ssl")] |
| Stream::Https(_) => true, |
| + Stream::Unix(_) => false, |
| } |
| } |
| |
| @@ -77,6 +88,7 @@ impl RefinedTcpStream { |
| Stream::Http(ref mut stream) => stream.peer_addr(), |
| #[cfg(feature = "ssl")] |
| Stream::Https(ref mut stream) => stream.lock().unwrap().get_ref().peer_addr(), |
| + Stream::Unix(_) => Err(std::io::Error::new(std::io::ErrorKind::Other, "Peer addresses are not supported for Unix sockets")), |
| } |
| } |
| } |
| @@ -94,6 +106,7 @@ impl Drop for RefinedTcpStream { |
| .get_mut() |
| .shutdown(Shutdown::Read) |
| .ok(), |
| + Stream::Unix(ref mut stream) => stream.shutdown(Shutdown::Read).ok(), |
| }; |
| } |
| |
| @@ -108,6 +121,7 @@ impl Drop for RefinedTcpStream { |
| .get_mut() |
| .shutdown(Shutdown::Write) |
| .ok(), |
| + Stream::Unix(ref mut stream) => stream.shutdown(Shutdown::Write).ok(), |
| }; |
| } |
| } |
| @@ -119,6 +133,7 @@ impl Read for RefinedTcpStream { |
| Stream::Http(ref mut stream) => stream.read(buf), |
| #[cfg(feature = "ssl")] |
| Stream::Https(ref mut stream) => stream.lock().unwrap().read(buf), |
| + Stream::Unix(ref mut stream) => stream.read(buf), |
| } |
| } |
| } |
| @@ -129,6 +144,7 @@ impl Write for RefinedTcpStream { |
| Stream::Http(ref mut stream) => stream.write(buf), |
| #[cfg(feature = "ssl")] |
| Stream::Https(ref mut stream) => stream.lock().unwrap().write(buf), |
| + Stream::Unix(ref mut stream) => stream.write(buf), |
| } |
| } |
| |
| @@ -137,6 +153,7 @@ impl Write for RefinedTcpStream { |
| Stream::Http(ref mut stream) => stream.flush(), |
| #[cfg(feature = "ssl")] |
| Stream::Https(ref mut stream) => stream.lock().unwrap().flush(), |
| + Stream::Unix(ref mut stream) => stream.flush(), |
| } |
| } |
| } |