Build an ngrok Alternative with Rust and Yamux

Network tunnel architecture diagram showing client and server communication through Yamux multiplexer

Technical Abstract

Network tunnels let endpoints that cannot reach each other directly, such as a local development server behind NAT, communicate securely; ngrok is a well-known example. This guide walks through a tunnel written in Rust on top of Yamux (Yet Another Multiplexer), which carries many logical streams over a single TCP connection. We examine how the codebase layers TLS support with SNI-based routing, protocol handlers, and stream management on top of that one connection.

Prerequisites:

  • Rust 1.75.0 or later
  • Basic understanding of network programming
  • Familiarity with async Rust
  • Cargo package manager

Technical Highlights

  • Implementation of bidirectional multiplexed tunnels using Yamux
  • Async/await based network programming with Tokio
  • TLS support with SNI-based routing
  • Protocol handlers for TCP, TLS, and HTTP(S)
  • Comprehensive error handling and connection recovery

Core Content Structure

Technical Overview

Mermaid diagram

System Architecture

Mermaid diagram

Implementation Details

First, let's look at the project dependencies from the codebase:

[package]
name = "tunnel"
version = "0.1.0"
edition = "2021"

[dependencies]
tokio = { version = "=1.38", features = ["full"] }
log = "0.4.22"
yamux = "0.13.3"
tokio-util = { version = "0.7.12", features = ["compat", "codec"] }
futures = { version = "0.3.30", features = ["async-await"] }
tokio-yamux = "0.3.8"
anyhow = "1.0.89"


The listener accepts client connections and upgrades each one to a Yamux session:

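use std::sync::Arc;

use log::{error, info};
use tokio::net::TcpListener;

use crate::multiplexing;
// `ProtocolHandler` is the codebase's handler trait; its module path is not shown in this excerpt.

pub struct YamuxListener<H> {
    address: String,
    protocol_handler: Arc<H>,
}

impl<H> YamuxListener<H>
where
    // Send + Sync + 'static are required to move the Arc<H> into tokio::spawn.
    H: ProtocolHandler + Send + Sync + 'static,
{
    pub fn new(address: String, protocol_handler: Arc<H>) -> Self {
        Self {
            address,
            protocol_handler,
        }
    }

    pub async fn run(&self) -> anyhow::Result<()> {
        let listener = TcpListener::bind(&self.address).await?;
        info!("Yamux Listener started on {}", self.address);

        loop {
            let (socket, addr) = listener.accept().await?;
            let protocol_handler = self.protocol_handler.clone();

            // One task per client connection, so a slow or failing client cannot stall the accept loop.
            tokio::spawn(async move {
                // Wrap the raw TCP socket in a Yamux session. Outbound maps to yamux Mode::Client
                // in upgrade_connection; the tunnel client must use the opposite direction.
                // The task returns (), so errors are logged instead of propagated with `?`.
                let yamux = match multiplexing::Yamux::upgrade_connection(
                    socket,
                    multiplexing::ConnectionDirection::Outbound,
                ) {
                    Ok(yamux) => yamux,
                    Err(e) => {
                        error!("Failed to upgrade connection from {}: {}", addr, e);
                        return;
                    }
                };

                // The control handle opens outbound streams; `incoming` yields streams opened by the peer.
                let yamux_control = yamux.get_yamux_control().clone();
                let mut incoming = yamux.into_incoming();

                if let Err(e) = protocol_handler
                    .initialize(yamux_control, &mut incoming)
                    .await
                {
                    error!("Error during protocol initialization for {}: {}", addr, e);
                }
            });
        }
    }
}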
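The listener hands all multiplexing to the yamux crate. The sketch below shows what "many logical streams over one TCP connection" means on the wire. Per the Yamux specification, every frame starts with a fixed 12-byte header: version, type, flags, stream ID, and length, all big-endian. The `Header`, `FrameType`, `encode` and `decode` names are hypothetical and illustrative only; the yamux crate has its own internal implementation.

```rust
/// Frame types defined by the Yamux specification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameType {
    Data = 0x0,
    WindowUpdate = 0x1,
    Ping = 0x2,
    GoAway = 0x3,
}

/// Flag bits from the Yamux specification (combined with bitwise OR).
pub const SYN: u16 = 0x1; // open a new stream
pub const ACK: u16 = 0x2; // acknowledge a new stream
pub const FIN: u16 = 0x4; // half-close a stream
pub const RST: u16 = 0x8; // reset a stream

pub const HEADER_SIZE: usize = 12;

/// The fixed header that precedes every Yamux frame (illustrative sketch).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Header {
    pub frame_type: FrameType,
    pub flags: u16,
    pub stream_id: u32,
    pub length: u32, // payload size for Data, window delta for WindowUpdate
}

impl Header {
    /// Layout: version(1) | type(1) | flags(2) | stream id(4) | length(4), big-endian.
    pub fn encode(&self) -> [u8; HEADER_SIZE] {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0] = 0; // protocol version 0
        buf[1] = self.frame_type as u8;
        buf[2..4].copy_from_slice(&self.flags.to_be_bytes());
        buf[4..8].copy_from_slice(&self.stream_id.to_be_bytes());
        buf[8..12].copy_from_slice(&self.length.to_be_bytes());
        buf
    }

    /// Decode a header, rejecting unknown versions and frame types.
    pub fn decode(buf: &[u8; HEADER_SIZE]) -> Option<Header> {
        if buf[0] != 0 {
            return None;
        }
        let frame_type = match buf[1] {
            0x0 => FrameType::Data,
            0x1 => FrameType::WindowUpdate,
            0x2 => FrameType::Ping,
            0x3 => FrameType::GoAway,
            _ => return None,
        };
        Some(Header {
            frame_type,
            flags: u16::from_be_bytes([buf[2], buf[3]]),
            stream_id: u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]),
            length: u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]),
        })
    }
}
```

The stream ID field lets the receiver route each frame's payload to the right logical stream. That is how one TCP socket carries many tunneled connections at once.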

Component Breakdown

The codebase implements several key components:

  1. Yamux Connection Management:

src/multiplexing/yamux.rs

    /// Upgrade the underlying socket to use yamux
    pub fn upgrade_connection<TSocket>(
        socket: TSocket,
        direction: crate::multiplexing::direction::ConnectionDirection,
    ) -> io::Result<Self>
    where
        TSocket: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
    {
        let mode = match direction {
            crate::multiplexing::direction::ConnectionDirection::Inbound => Mode::Server,
            crate::multiplexing::direction::ConnectionDirection::Outbound => Mode::Client,
        };

        let config = yamux::Config::default();

        let substream_counter = AtomicRefCounter::new();
        let connection = yamux::Connection::new(socket.compat(), config, mode);
        let (control, incoming) =
            Self::spawn_incoming_stream_worker(connection, substream_counter.clone());

        Ok(Self {
            control,
            incoming,
            substream_counter,
        })
    }
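The `Mode` chosen in `upgrade_connection` matters because the Yamux specification has the client side allocate odd stream IDs and the server side even ones. The two ends of one connection must therefore use opposite directions, or both would hand out the same IDs when opening streams. The enums below are hypothetical mirrors of `ConnectionDirection` and `yamux::Mode`, and `StreamIdAllocator` is an illustrative sketch, not the crate's code.

```rust
/// Hypothetical mirror of the codebase's ConnectionDirection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionDirection {
    Inbound,
    Outbound,
}

/// Hypothetical mirror of yamux::Mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    Client,
    Server,
}

/// Same mapping as upgrade_connection above.
pub fn mode_for(direction: ConnectionDirection) -> Mode {
    match direction {
        ConnectionDirection::Inbound => Mode::Server,
        ConnectionDirection::Outbound => Mode::Client,
    }
}

/// Allocates local stream IDs: odd for clients, even for servers.
pub struct StreamIdAllocator {
    next: u32,
}

impl StreamIdAllocator {
    pub fn new(mode: Mode) -> Self {
        let next = match mode {
            Mode::Client => 1,
            Mode::Server => 2,
        };
        StreamIdAllocator { next }
    }

    /// Returns the next ID, stepping by 2 to stay on this side's parity.
    /// (A real implementation must also handle ID exhaustion.)
    pub fn next_id(&mut self) -> u32 {
        let id = self.next;
        self.next += 2;
        id
    }
}
```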
  2. Protocol Handlers:

src/handlers.rs

// Individual handler for each TLS connection
pub struct TlsConnectionHandler {
    mux_control: Control,
    pub domains: Vec<String>,
}

impl TlsConnectionHandler {
    pub fn new(mux_control: Control, domains: Vec<String>) -> Self {
        Self {
            mux_control,
            domains,
        }
    }
}

#[async_trait]
impl TlsHandler for TlsConnectionHandler {
    fn domains(&self) -> &Vec<String> {
        &self.domains
    }

    fn needs_handshake(&self) -> bool {
        false
    }

    async fn handle_socket(
        &self,
        mut stream: Box<dyn AsyncReadWrite + Unpin + Send>,
        _domain: String,
    ) -> anyhow::Result<()> {
        // TLS is not terminated here (needs_handshake() returns false): open a fresh
        // Yamux stream for this connection and announce it with a Stream control packet,
        // then forward the raw bytes to the tunnel client.
        let mut yamux_stream = self.mux_control.clone().open_stream().await?;
        let stream_id = StreamId::generate();
        let control_packet = ControlPacket::Stream(stream_id);
        yamux_stream.write_all(&control_packet.serialize()).await?;
        yamux_stream.flush().await?;

        // Pump bytes in both directions until either side closes.
        tokio::spawn(async move {
            match tokio::io::copy_bidirectional(&mut stream, &mut yamux_stream).await {
                Ok(_) => {
                    debug!("Stream closed normally");
                }
                // Peer resets and already-closed sockets are routine when a client
                // disconnects, so they are logged as normal closures.
                Err(ref e)
                    if matches!(
                        e.kind(),
                        std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::NotConnected
                    ) =>
                {
                    debug!("Peer closed the connection ({}), treating as normal closure", e);
                }
                Err(e) => {
                    error!("Stream error while handling tcp: {}", e);
                }
            }
        });

        Ok(())
    }

    async fn check_connection(&self) -> anyhow::Result<()> {
        Ok(()) // Default implementation
    }
}
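The handler decides which relay errors count as a normal disconnect. Moving that policy into a small function makes it testable without sockets. The sketch below assumes the same two error kinds the handler treats as benign. `RelayOutcome`, `is_benign_disconnect` and `classify` are hypothetical names. The `(u64, u64)` success value matches what `tokio::io::copy_bidirectional` returns: bytes copied in each direction.

```rust
use std::io;

/// Outcome of a finished bidirectional relay.
#[derive(Debug, PartialEq, Eq)]
pub enum RelayOutcome {
    /// EOF, or the peer went away in a way the handler treats as normal.
    Closed,
    /// Any other error, worth logging at error level.
    Failed(io::ErrorKind),
}

/// True for the error kinds the TLS handler above logs as normal closure.
pub fn is_benign_disconnect(err: &io::Error) -> bool {
    matches!(
        err.kind(),
        io::ErrorKind::ConnectionReset | io::ErrorKind::NotConnected
    )
}

/// Classify a relay result: (bytes a->b, bytes b->a) on success, or an error.
pub fn classify(result: io::Result<(u64, u64)>) -> RelayOutcome {
    match result {
        Ok(_) => RelayOutcome::Closed,
        Err(e) if is_benign_disconnect(&e) => RelayOutcome::Closed,
        Err(e) => RelayOutcome::Failed(e.kind()),
    }
}
```

The spawned task could then `match classify(result)` and log `Closed` at debug level and `Failed` at error level. This keeps the logging policy in one place if more error kinds need the same treatment later.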
  3. TLS Integration:

src/utils/tls.rs

pub fn load_tls_config(cert_path: &str, key_path: &str) -> anyhow::Result<ServerConfig> {
    let certs = load_certs(cert_path)?;
    let key = load_private_key(key_path)?;

    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)?;
    config.alpn_protocols = vec![HTTP_1_1.to_vec(), HTTP_2.to_vec(), ACME_TLS_ALPN_NAME.to_vec()];
    Ok(config)
}
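`load_tls_config` advertises three ALPN identifiers. Their definitions are not shown in the excerpt. The values below are the registered ALPN protocol IDs: `http/1.1` and `h2` from the IANA ALPN registry, and `acme-tls/1` from RFC 8737. `acme-tls/1` lets the server answer ACME TLS-ALPN-01 certificate challenges. rustls performs the real negotiation internally. `select_alpn` is a hypothetical sketch of one common policy, server preference, shown for illustration and not as rustls's code.

```rust
/// Registered ALPN protocol identifiers (assumed to match the codebase's constants).
pub const HTTP_1_1: &[u8] = b"http/1.1";
pub const HTTP_2: &[u8] = b"h2";
pub const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";

/// Server-preference selection: return the first protocol in the server's
/// list that the client also offered, or None if there is no overlap.
pub fn select_alpn<'a>(server: &[&'a [u8]], client: &[&[u8]]) -> Option<&'a [u8]> {
    server
        .iter()
        .copied()
        .find(|proto| client.iter().any(|offered| offered == proto))
}
```

With the server order used in `load_tls_config`, a client offering both `h2` and `http/1.1` would get `http/1.1` under this policy. The order of `alpn_protocols` is therefore a deliberate choice, not an incidental one.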

Security Considerations

The codebase implements several security features:

  1. TLS Configuration: load_tls_config (in src/utils/tls.rs, shown above) builds a rustls ServerConfig from a certificate chain and private key loaded from disk. It uses with_no_client_auth, so clients are not authenticated at the TLS layer. It also advertises ALPN identifiers for HTTP/1.1, HTTP/2, and ACME TLS-ALPN challenges.
  2. SNI Support:

src/streams/sni.rs

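use std::io::Cursor;

use tokio::net::TcpStream;

/// A TCP stream whose SNI hostname has already been extracted. `buffer` holds
/// the bytes consumed while peeking at the TLS ClientHello.
pub struct SniStream {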
    pub sni: String,
    buffer: Cursor<Vec<u8>>,
    pub stream: TcpStream,
}

impl SniStream {
    pub fn new(sni: String, buffer: Vec<u8>, stream: TcpStream) -> Self {
        SniStream {
            sni,
            buffer: Cursor::new(buffer),
            stream,
        }
    }
    pub fn get_stream(&mut self) -> &mut TcpStream {
        &mut self.stream
    }
}
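Reading the SNI hostname consumes the start of the ClientHello. Those bytes must be replayed before any further reads, or the TLS session forwarded to the backend would be missing its first bytes. The excerpt does not show how `SniStream` reads from its buffer, so this is a hypothetical synchronous analogue (`ReplayStream`) that drains the buffered prefix first and then falls through to the live stream. For blocking I/O, std's `Read::chain` (`Cursor::new(prefix).chain(stream)`) achieves the same effect. An async version would apply the same idea inside `poll_read`.

```rust
use std::io::{self, Cursor, Read};

/// Hypothetical sync analogue of SniStream: replays already-read bytes first.
pub struct ReplayStream<S> {
    pub sni: String,
    buffer: Cursor<Vec<u8>>,
    inner: S,
}

impl<S: Read> ReplayStream<S> {
    pub fn new(sni: String, buffer: Vec<u8>, inner: S) -> Self {
        ReplayStream { sni, buffer: Cursor::new(buffer), inner }
    }
}

impl<S: Read> Read for ReplayStream<S> {
    fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
        // Drain the buffered ClientHello prefix first...
        let n = self.buffer.read(out)?;
        if n > 0 || out.is_empty() {
            return Ok(n);
        }
        // ...then read from the underlying stream once the prefix is exhausted.
        self.inner.read(out)
    }
}
```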

Performance Optimization

The codebase includes several performance optimizations:

  1. Efficient Stream Management:

src/multiplexing/yamux.rs

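    // Excerpt from the Stream impl that wraps incoming Yamux streams (impl header not shown):
    // each stream received from the background worker gets a compat adapter and a counter guard.
    type Item = Substream;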

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match futures::ready!(Pin::new(&mut self.inner).poll_recv(cx)) {
            Some(stream) => Poll::Ready(Some(Substream {
                stream: stream.compat(),
                _counter_guard: self.substream_counter.new_guard(),
            })),
            None => Poll::Ready(None),
        }
    }
}
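Each `Substream` carries a `_counter_guard`, so the number of live substreams is tracked without explicit bookkeeping. The count goes up when a guard is created and goes down automatically when the `Substream` is dropped. `AtomicRefCounter` itself is not shown in the excerpt. `LiveCounter` and `CounterGuard` below are hypothetical stand-ins that sketch the same RAII idea with an `Arc<AtomicUsize>`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for AtomicRefCounter: counts live guards.
#[derive(Clone, Default)]
pub struct LiveCounter {
    count: Arc<AtomicUsize>,
}

/// Decrements the shared count when dropped (RAII).
pub struct CounterGuard {
    count: Arc<AtomicUsize>,
}

impl LiveCounter {
    pub fn new() -> Self {
        Self::default()
    }

    /// Increment the count and return a guard that undoes it on drop.
    pub fn new_guard(&self) -> CounterGuard {
        // Relaxed is enough: the value is a statistic, not a synchronization point.
        self.count.fetch_add(1, Ordering::Relaxed);
        CounterGuard { count: Arc::clone(&self.count) }
    }

    pub fn get(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }
}

impl Drop for CounterGuard {
    fn drop(&mut self) {
        self.count.fetch_sub(1, Ordering::Relaxed);
    }
}
```

A count of zero can then be read as "no tunneled connections are open on this session". Because the decrement lives in `Drop`, it also happens on early returns and panics.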
  2. Control Handshake and Keepalive:

src/handlers.rs

        // Start the listener
        handler.start_listener().await?;
        let mut control = control.clone();

        // Open a new Yamux stream for this tunnel's control messages.
        let mut yamux_stream = control.open_stream().await?;

        // 1. Send ServerInfo (server version and region).
        let stream_id = StreamId::generate();
        let server_info =
            ControlPacket::ServerInfo(stream_id, VERSION.to_string(), REGION.to_string());
        let serialized = server_info.serialize();
        yamux_stream.write_all(serialized.as_slice()).await?;
        yamux_stream.flush().await?;
        info!("tcp: Sent server info");

        // 2. Wait for the client's ServerInfoAck.
        match wait_for_packet(&mut yamux_stream).await? {
            ControlPacket::ServerInfoAck(_) => {
                info!("tcp: Received ServerInfoAck");
            }
            _ => {
                error!("Unexpected packet type");
                return Err(anyhow::anyhow!("Unexpected packet type"));
            }
        }

        // 3. Send the assigned public port to the client.
        let control_packet = ControlPacket::PortAssignment(stream_id, assigned_port);
        yamux_stream.write_all(&control_packet.serialize()).await?;
        yamux_stream.flush().await?;

        // 4. Answer keepalive pings on the same stream until it closes. Errors end
        //    the task instead of panicking it.
        tokio::spawn(async move {
            info!("tcp: Listening for ping");
            let mut buffer = [0u8; 1024];
            loop {
                // Assumes each read() returns exactly one whole packet.
                let n = match yamux_stream.read(&mut buffer).await {
                    Ok(0) => {
                        info!("tcp: Control stream closed");
                        break;
                    }
                    Ok(n) => n,
                    Err(e) => {
                        error!("tcp: Control stream read error: {}", e);
                        break;
                    }
                };
                // Assumes deserialize returns a Result (the original called unwrap on it).
                let packet = match ControlPacket::deserialize(&buffer[..n]) {
                    Ok(packet) => packet,
                    Err(e) => {
                        error!("tcp: Failed to decode control packet: {:?}", e);
                        continue;
                    }
                };
                if let ControlPacket::Ping(_) = packet {
                    // Reply with a Pong carrying the current Unix time in milliseconds.
                    let now_as_millis = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .map(|d| d.as_millis() as u64)
                        .unwrap_or(0);
                    let pong_packet = ControlPacket::Pong(stream_id.clone(), now_as_millis);
                    let serialized = pong_packet.serialize();
                    debug!("tcp: Sending pong {:?}", serialized);
                    if let Err(e) = yamux_stream.write_all(serialized.as_slice()).await {
                        error!("tcp: Failed to send pong: {}", e);
                        break;
                    }
                    if let Err(e) = yamux_stream.flush().await {
                        error!("tcp: Failed to flush pong: {}", e);
                        break;
                    }
                }
            }
        });
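The ping loop treats each `read()` result as exactly one `ControlPacket`. A byte stream, a Yamux stream included, does not preserve message boundaries, so a packet can arrive split across reads or merged with the next one. A common fix is length-prefixed framing. The sketch below is hypothetical: the excerpt does not show the wire format of `ControlPacket::serialize`, and `write_frame`, `read_frame` and `MAX_FRAME_LEN` are illustrative names. It uses blocking `std::io` to stay self-contained; async code would use `AsyncReadExt::read_exact` the same way. Alternatively, `tokio_util::codec::LengthDelimitedCodec` from the already-listed `tokio-util` dependency (with the `codec` feature) provides this framing.

```rust
use std::io::{self, Read, Write};

/// Upper bound on a frame, so a corrupt length cannot trigger a huge allocation.
pub const MAX_FRAME_LEN: usize = 64 * 1024;

/// Write one frame: a 4-byte big-endian length followed by the payload.
pub fn write_frame<W: Write>(w: &mut W, payload: &[u8]) -> io::Result<()> {
    if payload.len() > MAX_FRAME_LEN {
        return Err(io::Error::new(io::ErrorKind::InvalidInput, "frame too large"));
    }
    w.write_all(&(payload.len() as u32).to_be_bytes())?;
    w.write_all(payload)?;
    w.flush()
}

/// Read exactly one frame. Returns Ok(None) on a clean EOF between frames.
pub fn read_frame<R: Read>(r: &mut R) -> io::Result<Option<Vec<u8>>> {
    let mut len_buf = [0u8; 4];
    let mut filled = 0;
    // Read the length prefix by hand to tell a clean EOF from a truncated header.
    while filled < len_buf.len() {
        match r.read(&mut len_buf[filled..]) {
            Ok(0) if filled == 0 => return Ok(None),
            Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
            Ok(n) => filled += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    let len = u32::from_be_bytes(len_buf) as usize;
    if len > MAX_FRAME_LEN {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "frame too large"));
    }
    // read_exact keeps reading until the whole payload has arrived.
    let mut payload = vec![0u8; len];
    r.read_exact(&mut payload)?;
    Ok(Some(payload))
}
```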

Technical Documentation

Error Handling

Errors from the Yamux control layer are modeled as a typed enum derived with thiserror, so callers can match on the failure mode:

src/multiplexing/error.rs

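// Assumed: `Canceled` is the oneshot cancellation error from the futures crate.
use futures::channel::oneshot::Canceled;
use thiserror::Error;

#[derive(Debug, Error, Clone)]
pub enum YamuxControlError {
    #[error("Yamux connection error: {0}")]
    ConnectionError(String),
    #[error("Yamux connection closed")]
    ConnectionClosed,
    #[error("Yamux request send error: {0}")]
    RequestSendError(String),
    // #[from] generates From<Canceled>, so `?` converts a canceled reply automatically.
    #[error("Yamux request canceled: {0}")]
    RequestCanceled(#[from] Canceled),
}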
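`#[from]` does two things: it generates a `From` impl, which is what lets `?` turn a `Canceled` into the enum, and it marks the field as the error's `source()`. The sketch below writes out roughly what thiserror generates for two of the variants. `Canceled` here is a hypothetical local stand-in for the futures type, and `ControlError` and `wait_for_reply` are illustrative names with illustrative messages.

```rust
use std::error::Error;
use std::fmt;

/// Hypothetical stand-in for futures::channel::oneshot::Canceled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Canceled;

impl fmt::Display for Canceled {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("oneshot canceled")
    }
}

impl Error for Canceled {}

/// Hand-written approximation of two thiserror-derived variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ControlError {
    ConnectionClosed,
    RequestCanceled(Canceled),
}

impl fmt::Display for ControlError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ControlError::ConnectionClosed => f.write_str("Yamux connection closed"),
            ControlError::RequestCanceled(e) => write!(f, "Yamux request canceled: {}", e),
        }
    }
}

impl Error for ControlError {
    // #[from] implies #[source]: the wrapped error is exposed as the cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            ControlError::RequestCanceled(e) => Some(e),
            ControlError::ConnectionClosed => None,
        }
    }
}

// The conversion #[from] generates; `?` calls it implicitly.
impl From<Canceled> for ControlError {
    fn from(e: Canceled) -> Self {
        ControlError::RequestCanceled(e)
    }
}

/// Simulates awaiting a reply whose sender may have been dropped.
pub fn wait_for_reply(reply: Result<u32, Canceled>) -> Result<u32, ControlError> {
    let value = reply?; // Canceled -> ControlError via From
    Ok(value)
}
```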

Testing Strategies

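The codebase includes integration tests. This one measures client reconnection latency: it connects to a live tunnel server, records how long setup takes, shuts the client down, and repeats 100 times.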

src/tests.rs

use std::time::Duration;

// Integration test: requires network access to the tunnel server configured below.
#[tokio::test]
async fn test_client_reconnection_performance() {
    const NUM_ITERATIONS: usize = 100;
    println!("Running {} iterations", NUM_ITERATIONS);
    let mut connection_times = Vec::with_capacity(NUM_ITERATIONS);
    let domain = "test.localho.st";
    for i in 0..NUM_ITERATIONS {
        let start_time = std::time::Instant::now();
        // Create client config
        let config = TunnelClientConfig {
            protocol: TunnelProtocol::Tls,
            server_address: "eu.tun.kfs.es:12345".to_string(),
            // server_address: "127.0.0.1:12345".to_string(),
            local_address: "127.0.0.1:3000".to_string(),
            domains: vec![domain.to_string()],
        };
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
        // Connect; only the "tunnel open" receiver is needed here
        let (_status, _management, open_rx) = run_tunnel_client(config, shutdown_rx)
            .await
            .unwrap();

        // Record connection time
        let connection_time = start_time.elapsed();
        connection_times.push(connection_time);
        // Disconnect
        let _ = shutdown_tx.send(());
        // Poll until the client reports the tunnel as closed
        while *open_rx.borrow() {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        // Give background tasks time to finish shutting down
        tokio::time::sleep(Duration::from_millis(100)).await;

        println!("Iteration {}: Connection time: {:?}", i + 1, connection_time);
    }

    // Calculate and print statistics
    let total_time: Duration = connection_times.iter().sum();
    let avg_time = total_time / NUM_ITERATIONS as u32;

    let max_time = connection_times.iter().max().unwrap();
    let min_time = connection_times.iter().min().unwrap();

    println!("\nPerformance Test Results:");
    println!("Total iterations: {}", NUM_ITERATIONS);
    println!("Average connection time: {:?}", avg_time);
    println!("Maximum connection time: {:?}", max_time);
    println!("Minimum connection time: {:?}", min_time);

    // Fail the test if the average connection time exceeds one second
    assert!(
        avg_time < Duration::from_secs(1),
        "Average connection time too high"
    );
}
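The statistics at the end of the test can be factored into a helper that handles an empty sample and also reports a tail percentile. Averages can hide occasional slow reconnects. `Stats` and `summarize` are hypothetical names; the percentile uses the nearest-rank method.

```rust
use std::time::Duration;

/// Summary of measured connection times.
#[derive(Debug, PartialEq, Eq)]
pub struct Stats {
    pub count: usize,
    pub avg: Duration,
    pub min: Duration,
    pub max: Duration,
    pub p95: Duration,
}

/// Returns None for an empty sample instead of dividing by zero or unwrapping.
pub fn summarize(samples: &[Duration]) -> Option<Stats> {
    if samples.is_empty() {
        return None;
    }
    let mut sorted = samples.to_vec();
    sorted.sort();
    let total: Duration = sorted.iter().sum();
    // Nearest-rank percentile: rank = ceil(0.95 * n), used as a 1-based index.
    let rank = (sorted.len() * 95 + 99) / 100;
    Some(Stats {
        count: sorted.len(),
        avg: total / sorted.len() as u32,
        min: sorted[0],
        max: sorted[sorted.len() - 1],
        p95: sorted[rank - 1],
    })
}
```

In the test, `summarize(&connection_times)` would replace the manual sum, `max` and `min` calls. The final assertion could then check `stats.p95` as well as `stats.avg`.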

Technical Conclusion

The codebase demonstrates a working implementation of network tunneling with Rust and Yamux, featuring:

  • Protocol handlers for TCP, TLS, and HTTP(S)
  • Typed error handling, with routine peer disconnects treated as normal closures
  • Stream multiplexing over a single TCP connection per client
  • TLS support with SNI routing
  • Integration tests, including a reconnection-latency test

Future improvements could include:

  • Enhanced metrics collection
  • Additional protocol handlers
  • Advanced compression support
  • Distributed connection pooling

Technical FAQ

  1. Q: How does the codebase handle TLS connections? A: Through the TLS handler implementation with SNI support and ALPN protocols.

  2. Q: What protocols are supported? A: TCP, TLS, and HTTP(S) through dedicated protocol handlers.

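  3. Q: How is connection multiplexing managed? A: With Yamux. Each tunneled connection gets its own logical stream over the single TCP connection between client and server, and Yamux's per-stream flow-control windows provide backpressure. The codebase uses yamux::Config::default().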

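  4. Q: How does error handling work? A: Through typed errors such as YamuxControlError (derived with thiserror) and anyhow::Result propagation with `?`. Routine disconnects such as ConnectionReset are logged at debug level rather than treated as failures.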
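Yamux backpressure is credit-based. Each stream starts with a 256 KiB receive window (the initial size in the Yamux specification). The sender may only send as many bytes as it has credit for, and the receiver grants more credit with Window Update frames once it has consumed data. The yamux crate manages these windows internally. `SendWindow` is a hypothetical sketch of the sender's side of the mechanism, not the crate's implementation.

```rust
/// Initial per-stream window from the Yamux specification (256 KiB).
pub const INITIAL_WINDOW: u32 = 256 * 1024;

/// Sender-side view of one stream's flow-control credit.
pub struct SendWindow {
    available: u32,
}

impl SendWindow {
    pub fn new() -> Self {
        SendWindow { available: INITIAL_WINDOW }
    }

    /// Grant up to `want` bytes from the current credit. Anything not granted
    /// must wait for a Window Update: this waiting is the backpressure.
    pub fn reserve(&mut self, want: u32) -> u32 {
        let granted = want.min(self.available);
        self.available -= granted;
        granted
    }

    /// Apply the delta carried in a Window Update frame from the receiver.
    pub fn on_window_update(&mut self, delta: u32) {
        self.available = self.available.saturating_add(delta);
    }

    pub fn available(&self) -> u32 {
        self.available
    }
}
```

A slow local service therefore stalls only its own stream. Other tunneled connections sharing the TCP socket keep their own credit and continue to flow.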

The implementation provides a robust foundation for building secure and efficient network tunnels in Rust.

Build an ngrok Alternative with Rust and Yamux - David Viejo