How to Build an ngrok Alternative: Creating a Tunnel with Rust and Yamux

Creating Network Tunnels in Rust with Yamux Multiplexer

Technical Abstract

Network tunneling is a critical component in modern distributed systems, allowing secure and efficient communication between network endpoints. This guide walks through the codebase of an ngrok-style tunnel written in Rust, which uses Yamux (Yet Another Multiplexer) to carry many logical streams over a single TCP connection. We'll examine how the codebase builds on that connection with TLS support, protocol handlers, and stream management.

Prerequisites:

  • Rust 1.75.0 or later
  • Basic understanding of network programming
  • Familiarity with async Rust
  • Cargo package manager

Technical Highlights

  • Implementation of bidirectional multiplexed tunnels using Yamux
  • Async/await based network programming with Tokio
  • TLS support with SNI-based routing
  • Protocol handlers for TCP, TLS, and HTTP(S)
  • Comprehensive error handling and connection recovery

Technical Overview

[Mermaid diagram: technical overview]

System Architecture

[Mermaid diagram: system architecture]

Implementation Details

First, let's look at the project dependencies from the codebase:

[package]
name = "tunnel"
version = "0.1.0"
edition = "2021"

[dependencies]
tokio = { version = "=1.38", features = ["full"] }
log = "0.4.22"
yamux = "0.13.3"
tokio-util = { version = "0.7.12", features = ["compat", "codec"] }
futures = { version = "0.3.30", features = ["async-await"] }
tokio-yamux = "0.3.8"
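anyhow = "1.0.89"
# Also used by the snippets below but not listed in this excerpt (versions assumed):
async-trait = "0.1"
thiserror = "1"
rustls = "0.23"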


The listener at the core of the server, which accepts tunnel clients and upgrades each TCP connection to a Yamux session:

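use std::sync::Arc;

use log::{debug, error, info};
use tokio::net::TcpListener;

// `multiplexing` and the `ProtocolHandler` trait are defined elsewhere in the crate.
use crate::multiplexing;

pub struct YamuxListener<H> {
    address: String,
    protocol_handler: Arc<H>,
}

impl<H> YamuxListener<H>
where
    // Send + Sync + 'static so the shared handler can move into spawned tasks.
    H: ProtocolHandler + Send + Sync + 'static,
{
    pub fn new(address: String, protocol_handler: Arc<H>) -> Self {
        Self {
            address,
            protocol_handler,
        }
    }

    pub async fn run(&self) -> anyhow::Result<()> {
        let listener = TcpListener::bind(&self.address).await?;
        info!("Yamux Listener started on {}", self.address);

        loop {
            let (socket, addr) = listener.accept().await?;
            debug!("Accepted tunnel connection from {}", addr);
            let protocol_handler = self.protocol_handler.clone();

            // One task per tunnel client, so a slow client cannot block accept().
            tokio::spawn(async move {
                // The task returns (), so `?` cannot be used here: log and stop instead.
                // The direction only picks the yamux Mode; the peer must use the opposite one.
                let yamux = match multiplexing::Yamux::upgrade_connection(
                    socket,
                    multiplexing::ConnectionDirection::Outbound,
                ) {
                    Ok(yamux) => yamux,
                    Err(e) => {
                        error!("Failed to upgrade connection from {} to yamux: {}", addr, e);
                        return;
                    }
                };

                // Handle for opening streams towards the client, plus the
                // stream of substreams the client opens towards us.
                let yamux_control = yamux.get_yamux_control().clone();
                let mut incoming = yamux.into_incoming();

                if let Err(e) = protocol_handler
                    .initialize(yamux_control, &mut incoming)
                    .await
                {
                    error!("Error during protocol initialization: {}", e);
                }
            });
        }
    }
}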
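
Under the hood, upgrade_connection hands the socket to the yamux crate, which splits it into logical streams by prefixing every frame with a fixed 12-byte header. The std-only sketch below encodes and decodes that header as laid out in the Yamux specification (version, type, flags, stream ID, length, all big-endian), purely to show how one TCP connection can carry many streams. `FrameHeader`, `FrameType` and the flag constants are illustrative names, not the yamux crate's API.

```rust
/// Size of every Yamux frame header in bytes.
pub const HEADER_SIZE: usize = 12;

/// Frame types from the Yamux specification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameType {
    Data = 0,
    WindowUpdate = 1,
    Ping = 2,
    GoAway = 3,
}

/// Header flags from the Yamux specification.
pub const FLAG_SYN: u16 = 0x1; // open a new stream
pub const FLAG_ACK: u16 = 0x2; // acknowledge a new stream
pub const FLAG_FIN: u16 = 0x4; // half-close the stream
pub const FLAG_RST: u16 = 0x8; // reset the stream

/// Illustrative decoded header (not the yamux crate's type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FrameHeader {
    pub frame_type: FrameType,
    pub flags: u16,
    pub stream_id: u32,
    pub length: u32,
}

impl FrameHeader {
    /// Layout: version (u8) | type (u8) | flags (u16) | stream ID (u32) | length (u32), big-endian.
    pub fn encode(&self) -> [u8; HEADER_SIZE] {
        let mut out = [0u8; HEADER_SIZE];
        out[0] = 0; // protocol version 0
        out[1] = self.frame_type as u8;
        out[2..4].copy_from_slice(&self.flags.to_be_bytes());
        out[4..8].copy_from_slice(&self.stream_id.to_be_bytes());
        out[8..12].copy_from_slice(&self.length.to_be_bytes());
        out
    }

    /// Returns None for an unknown version or frame type.
    pub fn decode(buf: &[u8; HEADER_SIZE]) -> Option<Self> {
        if buf[0] != 0 {
            return None;
        }
        let frame_type = match buf[1] {
            0 => FrameType::Data,
            1 => FrameType::WindowUpdate,
            2 => FrameType::Ping,
            3 => FrameType::GoAway,
            _ => return None,
        };
        Some(Self {
            frame_type,
            flags: u16::from_be_bytes([buf[2], buf[3]]),
            stream_id: u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]),
            length: u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]),
        })
    }
}
```

A Data frame's `length` counts the payload bytes that follow the header; for a WindowUpdate frame the same field carries the window increment instead.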

Component Breakdown

The codebase implements several key components:

  1. Yamux Connection Management:

src/multiplexing/yamux.rs

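    /// Upgrade the underlying socket to use yamux.
    ///
    /// `direction` only selects the yamux `Mode` (Inbound -> Server,
    /// Outbound -> Client); the two ends of a connection must use opposite modes.
    pub fn upgrade_connection<TSocket>(
        socket: TSocket,
        direction: crate::multiplexing::direction::ConnectionDirection,
    ) -> io::Result<Self>
    where
        TSocket: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
    {
        let mode = match direction {
            crate::multiplexing::direction::ConnectionDirection::Inbound => Mode::Server,
            crate::multiplexing::direction::ConnectionDirection::Outbound => Mode::Client,
        };

        // Default yamux settings (window sizes, stream limits)
        let config = yamux::Config::default();

        // Shared count of live substreams; each substream holds a guard
        let substream_counter = AtomicRefCounter::new();
        // yamux expects futures-io traits, so adapt the tokio socket with compat()
        let connection = yamux::Connection::new(socket.compat(), config, mode);
        // Spawn the worker that drives the connection; it returns a control
        // handle for opening streams and a receiver of inbound substreams
        let (control, incoming) =
            Self::spawn_incoming_stream_worker(connection, substream_counter.clone());

        Ok(Self {
            control,
            incoming,
            substream_counter,
        })
    }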
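
The `direction` argument only chooses the Yamux mode, and in Yamux the mode decides stream-ID parity: per the specification the client side opens odd-numbered streams, the server side even-numbered ones, and ID 0 is reserved for the session. That lets both ends open streams without coordinating, as long as they use opposite modes. A std-only sketch of the rule, with hypothetical `Mode` and `StreamIdAllocator` types standing in for the crate's internals:

```rust
/// Illustrative stand-in for yamux's client/server mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    Client,
    Server,
}

/// Allocates outbound stream IDs following the Yamux rule:
/// clients use odd IDs, servers even IDs, and 0 is reserved for the session.
pub struct StreamIdAllocator {
    next: Option<u32>,
}

impl StreamIdAllocator {
    pub fn new(mode: Mode) -> Self {
        let first = match mode {
            Mode::Client => 1,
            Mode::Server => 2,
        };
        Self { next: Some(first) }
    }

    /// Returns the next free ID, or None once the 32-bit ID space is used up.
    pub fn next_id(&mut self) -> Option<u32> {
        let id = self.next?;
        // Step by 2 to keep this side's parity; checked_add yields None on overflow.
        self.next = id.checked_add(2);
        Some(id)
    }
}
```

If both ends picked the same mode, their first outbound streams would both get ID 1 (or 2) and collide, which is why the direction passed on each side of the tunnel has to mirror the other's.
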
  2. Protocol Handlers:

src/handlers.rs

use async_trait::async_trait;
use log::{debug, error};
use tokio::io::AsyncWriteExt; // write_all / flush on the yamux stream

// Control, ControlPacket, StreamId, TlsHandler and AsyncReadWrite are defined elsewhere in the crate.

// Individual handler for each TLS connection
pub struct TlsConnectionHandler {
    mux_control: Control,
    pub domains: Vec<String>,
}

impl TlsConnectionHandler {
    pub fn new(mux_control: Control, domains: Vec<String>) -> Self {
        Self {
            mux_control,
            domains,
        }
    }
}

#[async_trait]
impl TlsHandler for TlsConnectionHandler {
    /// Domains served by this tunnel.
    fn domains(&self) -> &Vec<String> {
        &self.domains
    }

    /// TLS passthrough: the server does not perform a handshake itself;
    /// the stream is forwarded as-is to the tunnel client.
    fn needs_handshake(&self) -> bool {
        false
    }

    async fn handle_socket(
        &self,
        mut stream: Box<dyn AsyncReadWrite + Unpin + Send>,
        _domain: String,
    ) -> anyhow::Result<()> {
        // Open a new yamux substream to the client and announce it with a Stream packet
        let mut yamux_stream = self.mux_control.clone().open_stream().await?;
        let stream_id = StreamId::generate();
        let control_packet = ControlPacket::Stream(stream_id);
        yamux_stream.write_all(&control_packet.serialize()).await?;
        yamux_stream.flush().await?;

        // Pipe bytes in both directions until either side closes
        tokio::spawn(async move {
            match tokio::io::copy_bidirectional(&mut stream, &mut yamux_stream).await {
                Ok(_) => {
                    debug!("Stream closed normally");
                }
                // Abrupt hang-ups are routine for proxied connections
                Err(ref e)
                    if matches!(
                        e.kind(),
                        std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::NotConnected
                    ) =>
                {
                    debug!("Peer disconnected ({}), treating as normal closure", e);
                }
                Err(e) => {
                    error!("Stream error while handling TLS connection: {}", e);
                }
            }
        });

        Ok(())
    }

    async fn check_connection(&self) -> anyhow::Result<()> {
        Ok(()) // Default implementation
    }
}

  3. TLS Integration:

src/utils/tls.rs

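pub fn load_tls_config(cert_path: &str, key_path: &str) -> anyhow::Result<ServerConfig> {
    // load_certs / load_private_key are helpers not shown in this excerpt
    let certs = load_certs(cert_path)?;
    let key = load_private_key(key_path)?;

    // rustls 0.22+ builder: default protocol versions and cipher suites,
    // no client-certificate authentication, one certificate chain for all connections
    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)?;
    // ALPN protocols offered to clients, most preferred first:
    // HTTP/1.1, HTTP/2, and the ACME TLS-ALPN-01 challenge protocol
    config.alpn_protocols = vec![HTTP_1_1.to_vec(), HTTP_2.to_vec(), ACME_TLS_ALPN_NAME.to_vec()];
    Ok(config)
}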
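
During the handshake the client sends the list of application protocols it supports and the server picks one of its own `alpn_protocols` (RFC 7301). rustls documents that field as most preferred first; the std-only sketch below models that server-preference choice and is not rustls's code. It assumes HTTP_1_1, HTTP_2 and ACME_TLS_ALPN_NAME hold the registered IDs `http/1.1`, `h2` and `acme-tls/1`.

```rust
// Assumed values of the codebase's constants (registered ALPN protocol IDs).
pub const HTTP_1_1: &[u8] = b"http/1.1";
pub const HTTP_2: &[u8] = b"h2";
pub const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";

/// Server-preference ALPN selection: walk the server's list in order and
/// return the first protocol the client also offered. None means no overlap;
/// a strict server would then abort the handshake.
pub fn select_alpn<'a>(server: &[&'a [u8]], client: &[&[u8]]) -> Option<&'a [u8]> {
    server
        .iter()
        .copied()
        .find(|proto| client.iter().any(|offered| offered == proto))
}
```

Under this policy, a browser offering both `h2` and `http/1.1` would be given HTTP/1.1, because HTTP_1_1 comes first in the vector; putting HTTP_2 first would flip that. RFC 8737 has an ACME validator offer only `acme-tls/1`, so it still matches.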

Security Considerations

The codebase implements several security features:

  1. TLS Configuration: load_tls_config in src/utils/tls.rs (shown above) builds a rustls ServerConfig from a certificate chain and private key read from the given paths, disables client-certificate authentication, and advertises ALPN for HTTP/1.1, HTTP/2, and the ACME TLS-ALPN challenge protocol.

  2. SNI Support:

src/streams/sni.rs

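use std::io::Cursor;
use tokio::net::TcpStream;

/// A TCP stream whose first bytes were already read to extract the SNI.
pub struct SniStream {
    /// Server name from the TLS ClientHello.
    pub sni: String,
    /// Bytes already read from `stream` while extracting the SNI.
    buffer: Cursor<Vec<u8>>,
    pub stream: TcpStream,
}

impl SniStream {
    pub fn new(sni: String, buffer: Vec<u8>, stream: TcpStream) -> Self {
        SniStream {
            sni,
            buffer: Cursor::new(buffer),
            stream,
        }
    }

    /// Direct access to the socket (bypasses the buffered bytes).
    pub fn get_stream(&mut self) -> &mut TcpStream {
        &mut self.stream
    }
}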
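
SniStream keeps the bytes that were already read from the socket to find the server name, presumably so they can be replayed: whoever handles the connection next still needs the complete ClientHello. The std-only sketch below shows that idea with blocking `std::io::Read` and `Read::chain`, which drains the buffered prefix before reading from the socket. `ReplayStream` is a hypothetical name; an async version would implement tokio's `AsyncRead` instead.

```rust
use std::io::{self, Cursor, Read};

/// Hypothetical blocking analogue of SniStream: bytes that were read ahead
/// (e.g. to inspect the ClientHello) are replayed before the live stream.
pub struct ReplayStream<S: Read> {
    pub sni: String,
    inner: io::Chain<Cursor<Vec<u8>>, S>,
}

impl<S: Read> ReplayStream<S> {
    pub fn new(sni: String, buffer: Vec<u8>, stream: S) -> Self {
        Self { sni, inner: Cursor::new(buffer).chain(stream) }
    }

    /// Writes can go straight to the socket; reads must use `read` below,
    /// otherwise the replayed prefix would be skipped.
    pub fn get_stream(&mut self) -> &mut S {
        self.inner.get_mut().1
    }
}

impl<S: Read> Read for ReplayStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Chain yields the buffered prefix first, then data from the stream.
        self.inner.read(buf)
    }
}
```

Note that the original struct's `stream` field is public: reading from it directly would skip the buffered prefix, while writing to it is harmless.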

Performance Optimization

The following parts of the codebase deal with stream lifecycle and connection health:

  1. Efficient Stream Management:

src/multiplexing/yamux.rs

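use std::pin::Pin;
use std::task::{Context, Poll};

use futures::Stream;
use tokio_util::compat::FuturesAsyncReadCompatExt;

// Type name assumed: the excerpt omits this impl header.
impl Stream for IncomingSubstreams {
    type Item = Substream;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Wait for the connection worker to hand over the next inbound stream
        match futures::ready!(Pin::new(&mut self.inner).poll_recv(cx)) {
            Some(stream) => Poll::Ready(Some(Substream {
                // Convert futures-io traits back to tokio's AsyncRead/AsyncWrite
                stream: stream.compat(),
                // Counts this substream as live until it is dropped
                _counter_guard: self.substream_counter.new_guard(),
            })),
            // Channel closed (the worker has stopped): no more substreams
            None => Poll::Ready(None),
        }
    }
}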
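
Each Substream carries a `_counter_guard` created by `substream_counter.new_guard()`, so the count of live substreams rises when one is handed out and falls when it is dropped, with no explicit bookkeeping. AtomicRefCounter's code isn't shown; the std-only sketch below is a hypothetical equivalent built on `Arc<AtomicUsize>` and `Drop`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical counter shared by a connection and all of its substreams.
#[derive(Clone, Default)]
pub struct SubstreamCounter {
    count: Arc<AtomicUsize>,
}

/// Increments the shared count when created and decrements it when dropped.
pub struct CounterGuard {
    count: Arc<AtomicUsize>,
}

impl SubstreamCounter {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn new_guard(&self) -> CounterGuard {
        self.count.fetch_add(1, Ordering::Relaxed);
        CounterGuard { count: Arc::clone(&self.count) }
    }

    /// Number of guards (live substreams) currently alive.
    pub fn get(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }
}

impl Drop for CounterGuard {
    fn drop(&mut self) {
        self.count.fetch_sub(1, Ordering::Relaxed);
    }
}
```

A count like this could, for instance, feed a metric or cap the number of concurrent tunneled connections per client. Relaxed ordering is enough because the counter does not protect any other data.
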
  2. Control Stream Handshake and Keepalive:

src/handlers.rs

        // Start the public listener for this tunnel
        handler.start_listener().await?;
        let mut control = control.clone();
        // Open a new Yamux stream to the client for control messages
        let mut yamux_stream = control.open_stream().await?;

        // 1. Send ServerInfo (server version and region)
        let stream_id = StreamId::generate();
        let server_info =
            ControlPacket::ServerInfo(stream_id, VERSION.to_string(), REGION.to_string());
        let serialized = server_info.serialize();
        yamux_stream.write_all(serialized.as_slice()).await?;
        yamux_stream.flush().await?;
        info!("tcp: Sent server info");

        // 2. Wait for the client's ServerInfoAck
        let server_info_ack = wait_for_packet(&mut yamux_stream).await?;
        match server_info_ack {
            ControlPacket::ServerInfoAck(_) => {
                info!("tcp: Received ServerInfoAck");
            }
            _ => {
                error!("Unexpected packet type");
                return Err(anyhow::anyhow!("Unexpected packet type"));
            }
        }

        // 3. Send the assigned public port to the client
        let control_packet = ControlPacket::PortAssignment(stream_id, assigned_port);
        yamux_stream.write_all(&control_packet.serialize()).await?;
        yamux_stream.flush().await?;

        // 4. Keepalive: answer every Ping with a Pong carrying the server clock.
        //    Errors end the task instead of panicking it.
        tokio::spawn(async move {
            info!("tcp: Listening for ping");
            let mut buffer = [0u8; 1024];
            loop {
                // Assumes each read returns exactly one whole packet
                let n = match yamux_stream.read(&mut buffer).await {
                    Ok(0) => {
                        info!("tcp: Control stream closed");
                        break;
                    }
                    Ok(n) => n,
                    Err(e) => {
                        error!("tcp: Control stream read failed: {}", e);
                        break;
                    }
                };
                // deserialize is assumed to return a Result
                let Ok(packet) = ControlPacket::deserialize(&buffer[..n]) else {
                    error!("tcp: Could not decode control packet");
                    continue;
                };
                if let ControlPacket::Ping(_) = packet {
                    // Reply with Pong(stream_id, Unix time in milliseconds)
                    let now_as_millis = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .map(|d| d.as_millis() as u64)
                        .unwrap_or(0); // clock before 1970: send 0 instead of panicking
                    let pong_packet = ControlPacket::Pong(stream_id.clone(), now_as_millis);
                    let serialized = pong_packet.serialize();
                    info!("tcp: Sending pong {:?}", serialized);
                    let sent = match yamux_stream.write_all(serialized.as_slice()).await {
                        Ok(()) => yamux_stream.flush().await,
                        Err(e) => Err(e),
                    };
                    if let Err(e) = sent {
                        error!("tcp: Failed to send pong: {}", e);
                        break;
                    }
                }
            }
        });
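
The ping loop treats each `read` as exactly one ControlPacket, but a Yamux substream is a byte stream like TCP: one read can return part of a packet, or two packets at once. wait_for_packet may already handle this for the handshake; its code isn't shown. A common remedy is to length-prefix every packet and reassemble frames from whatever the reads return. The sketch below assumes a hypothetical 4-byte big-endian length prefix, since ControlPacket's real wire format isn't shown, and an assumed 64 KiB frame limit.

```rust
/// Accumulates bytes from arbitrary-sized reads and yields complete frames.
/// Hypothetical wire format: u32 big-endian payload length, then the payload.
#[derive(Default)]
pub struct FrameDecoder {
    buf: Vec<u8>,
}

impl FrameDecoder {
    /// Upper bound on a frame so a corrupt length cannot exhaust memory (assumed value).
    pub const MAX_FRAME: usize = 64 * 1024;

    /// Append freshly read bytes.
    pub fn extend(&mut self, chunk: &[u8]) {
        self.buf.extend_from_slice(chunk);
    }

    /// Pop the next complete frame: Ok(None) if more bytes are needed,
    /// Err if the declared length exceeds MAX_FRAME.
    pub fn next_frame(&mut self) -> Result<Option<Vec<u8>>, String> {
        if self.buf.len() < 4 {
            return Ok(None);
        }
        let len = u32::from_be_bytes([self.buf[0], self.buf[1], self.buf[2], self.buf[3]]) as usize;
        if len > Self::MAX_FRAME {
            return Err(format!("frame of {len} bytes exceeds limit"));
        }
        if self.buf.len() < 4 + len {
            return Ok(None);
        }
        let frame = self.buf[4..4 + len].to_vec();
        // Drop the consumed prefix and payload, keeping any following frame
        self.buf.drain(..4 + len);
        Ok(Some(frame))
    }
}

/// Prefix a payload with its length for sending.
pub fn encode_frame(payload: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(4 + payload.len());
    out.extend_from_slice(&(payload.len() as u32).to_be_bytes());
    out.extend_from_slice(payload);
    out
}
```

In the loop, each `read` would feed `decoder.extend(&buffer[..n])`, followed by draining `next_frame()` until it returns `Ok(None)`. Because the project already enables tokio-util's `codec` feature, `tokio_util::codec::LengthDelimitedCodec` with `FramedRead` is an off-the-shelf alternative that uses the same 4-byte big-endian prefix by default.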

Technical Documentation

Error Handling

Failures on the Yamux control path have their own error type, derived with thiserror:

src/multiplexing/error.rs

use futures::channel::oneshot::Canceled; // assumed source of `Canceled`
use thiserror::Error;

#[derive(Debug, Error, Clone)]
pub enum YamuxControlError {
    #[error("Yamux connection error: {0}")]
    ConnectionError(String),
    #[error("Yamux connection closed")]
    ConnectionClosed,
    #[error("Yamux request send error: {0}")]
    RequestSendError(String),
    #[error("Yamux request canceled: {0}")]
    RequestCanceled(#[from] Canceled),
}
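
thiserror's derive turns each `#[error(...)]` attribute into a Display arm and each `#[from]` into a From impl (which also marks the field as the error's source), so `?` can convert a Canceled into YamuxControlError. For readers new to the crate, the std-only sketch below writes roughly equivalent impls by hand. `Canceled` here is a local stand-in for the imported type, and the code thiserror actually generates differs in detail.

```rust
use std::error::Error;
use std::fmt;

/// Local stand-in for the oneshot `Canceled` error (assumption for this sketch).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Canceled;

impl fmt::Display for Canceled {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("oneshot canceled")
    }
}

impl Error for Canceled {}

#[derive(Debug, Clone)]
pub enum YamuxControlError {
    ConnectionError(String),
    ConnectionClosed,
    RequestSendError(String),
    RequestCanceled(Canceled),
}

// Roughly what the #[error("...")] attributes expand to.
impl fmt::Display for YamuxControlError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ConnectionError(msg) => write!(f, "Yamux connection error: {msg}"),
            Self::ConnectionClosed => f.write_str("Yamux connection closed"),
            Self::RequestSendError(msg) => write!(f, "Yamux request send error: {msg}"),
            Self::RequestCanceled(e) => write!(f, "Yamux request canceled: {e}"),
        }
    }
}

// #[from] implies #[source]: the wrapped error is exposed through source().
impl Error for YamuxControlError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::RequestCanceled(e) => Some(e as &(dyn Error + 'static)),
            _ => None,
        }
    }
}

// Roughly what #[from] expands to; it lets `?` turn a Canceled into this error.
impl From<Canceled> for YamuxControlError {
    fn from(e: Canceled) -> Self {
        Self::RequestCanceled(e)
    }
}
```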

Testing Strategies

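The codebase includes integration tests. The one below measures how long a client takes to connect, repeated over 100 connect/disconnect cycles: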

src/tests.rs

use std::time::Duration;

// Requires network access: it connects to the public eu.tun.kfs.es server and
// expects a local service on 127.0.0.1:3000, so it is opt-in.
// Run with `cargo test -- --ignored`.
#[tokio::test]
#[ignore]
async fn test_client_reconnection_performance() {
    const NUM_ITERATIONS: usize = 100;
    println!("Running {} iterations", NUM_ITERATIONS);
    let mut connection_times = Vec::with_capacity(NUM_ITERATIONS);
    let domain = "test.localho.st";
    for i in 1..=NUM_ITERATIONS {
        let start_time = std::time::Instant::now();
        // Create client config
        let config = TunnelClientConfig {
            protocol: TunnelProtocol::Tls,
            server_address: "eu.tun.kfs.es:12345".to_string(),
            // For a local server use "127.0.0.1:12345" instead
            local_address: "127.0.0.1:3000".to_string(),
            domains: vec![domain.to_string()],
        };
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
        // Connect; only the open/closed signal is needed here
        let (_status, _management, open_rx) = run_tunnel_client(config, shutdown_rx)
            .await
            .unwrap();

        // Record connection time
        let connection_time = start_time.elapsed();
        connection_times.push(connection_time);
        // Disconnect
        let _ = shutdown_tx.send(());
        // Poll until the client reports the tunnel as closed
        while *open_rx.borrow() {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        // Give background tasks time to finish shutting down
        tokio::time::sleep(Duration::from_millis(100)).await;

        println!("Iteration {}: Connection time: {:?}", i, connection_time);
    }

    // Calculate and print statistics
    let total_time: Duration = connection_times.iter().sum();
    let avg_time = total_time / NUM_ITERATIONS as u32;

    let max_time = connection_times.iter().max().unwrap();
    let min_time = connection_times.iter().min().unwrap();

    println!("\nPerformance Test Results:");
    println!("Total iterations: {}", NUM_ITERATIONS);
    println!("Average connection time: {:?}", avg_time);
    println!("Maximum connection time: {:?}", max_time);
    println!("Minimum connection time: {:?}", min_time);

    // Loose sanity bound; actual times depend on latency to the server
    assert!(
        avg_time < Duration::from_secs(1),
        "Average connection time too high: {:?}",
        avg_time
    );
}
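
The test reports mean, minimum and maximum, but against a remote server one slow handshake can skew the mean and decide the maximum, so a percentile is often a steadier figure. The std-only helper below summarizes the collected `Vec<Duration>` and adds a nearest-rank 95th percentile. `ConnStats` and `summarize` are hypothetical names, and nearest rank is one of several percentile definitions in use.

```rust
use std::time::Duration;

/// Summary of a set of connection times.
#[derive(Debug, PartialEq)]
pub struct ConnStats {
    pub min: Duration,
    pub max: Duration,
    pub mean: Duration,
    pub p95: Duration,
}

/// Nearest-rank percentile over a non-empty, ascending-sorted slice (`p` in 1..=100).
fn percentile(sorted: &[Duration], p: usize) -> Duration {
    // rank = ceil(p * n / 100), computed in integers to avoid float rounding
    let rank = (p * sorted.len() + 99) / 100;
    sorted[rank.clamp(1, sorted.len()) - 1]
}

/// Returns None for an empty sample instead of panicking or dividing by zero.
pub fn summarize(times: &[Duration]) -> Option<ConnStats> {
    if times.is_empty() {
        return None;
    }
    let mut sorted = times.to_vec();
    sorted.sort();
    let total: Duration = sorted.iter().sum();
    Some(ConnStats {
        min: sorted[0],
        max: sorted[sorted.len() - 1],
        mean: total / sorted.len() as u32,
        p95: percentile(&sorted, 95),
    })
}
```

The test could call `summarize(&connection_times)` after the loop and assert on `p95` alongside the mean.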

Technical Conclusion

The codebase demonstrates a working implementation of network tunneling using Rust and Yamux, featuring:

  • Protocol handling for TCP, TLS, and HTTP(S)
  • Error propagation with anyhow, a typed Yamux control error, and graceful handling of expected disconnects
  • Stream multiplexing over a single connection per client
  • TLS support with SNI routing
  • Integration tests, including a reconnection-performance test

Future improvements could include:

  • Enhanced metrics collection
  • Additional protocol handlers
  • Advanced compression support
  • Distributed connection pooling

Technical FAQ

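  1. Q: How does the codebase handle TLS connections? A: Connections are routed by their SNI server name, which SniStream keeps alongside the bytes already read. TlsConnectionHandler does not perform a handshake (needs_handshake returns false); it forwards the stream over a new Yamux substream. The rustls configuration advertises ALPN for HTTP/1.1, HTTP/2, and ACME TLS-ALPN.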

  2. Q: What protocols are supported? A: TCP, TLS, and HTTP(S) through dedicated protocol handlers.

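  3. Q: How is connection multiplexing managed? A: Each tunneled connection becomes its own Yamux substream over the client's single TCP connection. The codebase uses yamux::Config::default(), and backpressure comes from Yamux's per-stream flow-control windows.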

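  4. Q: How does error handling work? A: Fallible functions return anyhow::Result and propagate errors with ?. Multiplexer control failures use the YamuxControlError enum, and expected disconnects such as connection resets are logged at debug level instead of being reported as errors.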
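
The backpressure mentioned in the third answer comes from Yamux's per-stream flow control. A sender may only transmit as many bytes as the receiver's window allows (256 KiB initially, per the Yamux specification), and the receiver reopens the window with WindowUpdate frames as its application consumes data. A sender with an exhausted window has to wait, which keeps a fast producer from overrunning a slow consumer. The std-only sketch below models only the sender's accounting; `SendWindow` is a hypothetical type and not how the yamux crate implements it.

```rust
/// Initial per-stream window from the Yamux specification (256 KiB).
pub const INITIAL_WINDOW: u32 = 256 * 1024;

/// Hypothetical sender-side view of one stream's flow-control window.
pub struct SendWindow {
    remaining: u32,
}

impl SendWindow {
    pub fn new() -> Self {
        Self { remaining: INITIAL_WINDOW }
    }

    /// How many of `wanted` bytes may be sent now; 0 means "wait for a WindowUpdate".
    pub fn reserve(&mut self, wanted: usize) -> usize {
        let n = wanted.min(self.remaining as usize);
        self.remaining -= n as u32;
        n
    }

    /// Apply the delta carried by a WindowUpdate frame from the receiver.
    pub fn on_window_update(&mut self, delta: u32) {
        self.remaining = self.remaining.saturating_add(delta);
    }

    pub fn remaining(&self) -> u32 {
        self.remaining
    }
}
```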

The implementation provides a robust foundation for building secure and efficient network tunnels in Rust.