//! TLS 1.2 Test Server with Client Tracking (mTLS Enabled) //! //! A mutual TLS (mTLS) HTTPS server for testing the eBPF TLS certificate validator. //! Requires client certificates signed by the CA for authentication. //! Tracks connected clients and displays their status. use std::collections::HashMap; use std::env; use std::fs; use std::net::SocketAddr; use std::path::Path; use std::sync::Arc; use packet_detector::tls_util::{dn, parse_pem}; use rcgen::{BasicConstraints, CertificateParams, IsCa, Issuer, KeyPair, KeyUsagePurpose, SanType}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::server::{ServerConfig, WebPkiClientVerifier}; use rustls::version::TLS12; use rustls::RootCertStore; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::sync::RwLock; use tokio_rustls::TlsAcceptor; const DEFAULT_PORT: u16 = 8443; const CA_CERT_PATH: &str = "ca_cert.pem"; const CA_KEY_PATH: &str = "ca_key.pem"; const SERVER_CERT_PATH: &str = "server_cert.pem"; const SERVER_KEY_PATH: &str = "server_key.pem"; const CLIENT_CERT_PATH: &str = "client_cert.pem"; const CLIENT_KEY_PATH: &str = "client_key.pem"; #[derive(Clone, Debug, serde::Serialize)] struct Client { ip: String, connected_at: chrono::DateTime, #[serde(skip)] last_seen: chrono::DateTime, requests: u64, } #[derive(Default)] struct State { clients: HashMap, connections: u64, requests: u64, } /// Generate a CA certificate for signing server and client certificates fn generate_ca_certificate() -> Result<(String, String, KeyPair), Box> { println!("Generating CA certificate..."); let mut params = CertificateParams::default(); params.distinguished_name = dn("eBPF Test CA", "Zero Trust Network"); params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); params.key_usages = vec![ KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign, ]; let key_pair = KeyPair::generate()?; let cert = params.self_signed(&key_pair)?; Ok((cert.pem(), key_pair.serialize_pem(), key_pair)) } /// Generate a server certificate signed by the CA fn generate_server_certificate( ca_cert_pem: &str, ca_key_pem: &str, ) -> Result<(String, String), Box> { println!("Generating server certificate signed by CA..."); let ca_key = KeyPair::from_pem(ca_key_pem)?; let issuer = Issuer::from_ca_cert_pem(ca_cert_pem, ca_key)?; let mut params = CertificateParams::default(); params.distinguished_name = dn("localhost", "eBPF Test Server"); params.subject_alt_names = vec![ SanType::DnsName("localhost".try_into()?), SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))), ]; params.key_usages = vec![ KeyUsagePurpose::DigitalSignature, KeyUsagePurpose::KeyEncipherment, ]; let key_pair = KeyPair::generate()?; let cert = params.signed_by(&key_pair, &issuer)?; Ok((cert.pem(), key_pair.serialize_pem())) } /// Generate a client certificate signed by the CA fn generate_client_certificate( client_name: &str, ca_cert_pem: &str, ca_key_pem: &str, ) -> Result<(String, String), Box> { println!("Generating client certificate for: {}", client_name); let ca_key = KeyPair::from_pem(ca_key_pem)?; let issuer = Issuer::from_ca_cert_pem(ca_cert_pem, ca_key)?; let mut params = CertificateParams::default(); params.distinguished_name = dn(client_name, "eBPF Test Client"); params.key_usages = vec![ KeyUsagePurpose::DigitalSignature, ]; let key_pair = KeyPair::generate()?; let cert = params.signed_by(&key_pair, &issuer)?; Ok((cert.pem(), key_pair.serialize_pem())) } /// PKI setup result struct PkiSetup { ca_cert_pem: String, server_cert: Vec>, server_key: PrivateKeyDer<'static>, } /// Load or generate the full PKI (CA, server cert, client cert) fn setup_pki() -> Result> { let ca_cert_pem: String; let ca_key_pem: String; // Load or generate CA if Path::new(CA_CERT_PATH).exists() && Path::new(CA_KEY_PATH).exists() { println!("Loading existing CA from {} and {}", CA_CERT_PATH, CA_KEY_PATH); ca_cert_pem = fs::read_to_string(CA_CERT_PATH)?; ca_key_pem = fs::read_to_string(CA_KEY_PATH)?; } else { println!("Generating new PKI infrastructure..."); let (cert, key, _) = generate_ca_certificate()?; fs::write(CA_CERT_PATH, &cert)?; fs::write(CA_KEY_PATH, &key)?; println!("Saved CA to {} and {}", CA_CERT_PATH, CA_KEY_PATH); ca_cert_pem = cert; ca_key_pem = key; } // Load or generate server certificate let server_cert_pem: String; let server_key_pem: String; if Path::new(SERVER_CERT_PATH).exists() && Path::new(SERVER_KEY_PATH).exists() { println!("Loading existing server certificate..."); server_cert_pem = fs::read_to_string(SERVER_CERT_PATH)?; server_key_pem = fs::read_to_string(SERVER_KEY_PATH)?; } else { let (cert, key) = generate_server_certificate(&ca_cert_pem, &ca_key_pem)?; fs::write(SERVER_CERT_PATH, &cert)?; fs::write(SERVER_KEY_PATH, &key)?; println!("Saved server cert to {} and {}", SERVER_CERT_PATH, SERVER_KEY_PATH); server_cert_pem = cert; server_key_pem = key; } // Load or generate client certificate if !Path::new(CLIENT_CERT_PATH).exists() || !Path::new(CLIENT_KEY_PATH).exists() { let (cert, key) = generate_client_certificate("test-client", &ca_cert_pem, &ca_key_pem)?; fs::write(CLIENT_CERT_PATH, &cert)?; fs::write(CLIENT_KEY_PATH, &key)?; println!("Saved client cert to {} and {}", CLIENT_CERT_PATH, CLIENT_KEY_PATH); } // Parse certificates let server_certs = parse_pem(&server_cert_pem)?; let server_key = rustls_pemfile::private_key(&mut server_key_pem.as_bytes())? .ok_or("No server private key found")?; Ok(PkiSetup { ca_cert_pem, server_cert: server_certs, server_key, }) } fn parse_request(data: &[u8]) -> (&str, Option) { let req = std::str::from_utf8(data).unwrap_or(""); let first_line = req.lines().next().unwrap_or(""); // Check X-Client-Name header let client = req.lines() .find(|l| l.to_lowercase().starts_with("x-client-name:")) .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string()) .or_else(|| { // Check ?client= query param first_line.find("client=").map(|i| { let rest = &first_line[i + 7..]; rest[..rest.find(|c| c == '&' || c == ' ').unwrap_or(rest.len())].to_string() }) }); (first_line, client) } async fn handle_connection( mut stream: tokio_rustls::server::TlsStream, peer: SocketAddr, state: Arc>, ) { state.write().await.connections += 1; let mut buf = vec![0u8; 4096]; let n = match stream.read(&mut buf).await { Ok(0) => return, Ok(n) => n, Err(_) => return, }; let (path, client_name) = parse_request(&buf[..n]); // Track client if name provided if let Some(name) = &client_name { let mut s = state.write().await; s.requests += 1; let now = chrono::Utc::now(); s.clients.entry(name.clone()) .and_modify(|c| { c.last_seen = now; c.requests += 1; }) .or_insert(Client { ip: peer.ip().to_string(), connected_at: now, last_seen: now, requests: 1, }); } // Route request let (ctype, body) = if path.contains("/status") || path.contains("/clients") { let s = state.read().await; ("application/json", serde_json::json!({ "clients": s.clients.len(), "connections": s.connections, "requests": s.requests, "list": &s.clients }).to_string()) } else if path.contains("/register") { ("application/json", serde_json::json!({ "status": "ok", "client": client_name.as_deref().unwrap_or("unknown") }).to_string()) } else { ("text/plain", format!("TLS Server OK - {} clients", state.read().await.clients.len())) }; let resp = format!( "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", ctype, body.len(), body ); let _ = stream.write_all(resp.as_bytes()).await; let _ = stream.shutdown().await; } #[tokio::main] async fn main() -> Result<(), Box> { rustls::crypto::ring::default_provider().install_default().ok(); let port: u16 = env::args().nth(1).and_then(|s| s.parse().ok()).unwrap_or(DEFAULT_PORT); let pki = setup_pki()?; // Build root store with CA let mut root_store = RootCertStore::empty(); for cert in parse_pem(&pki.ca_cert_pem)? { root_store.add(cert)?; } let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; let config = ServerConfig::builder_with_protocol_versions(&[&TLS12]) .with_client_cert_verifier(verifier) .with_single_cert(pki.server_cert, pki.server_key)?; let acceptor = TlsAcceptor::from(Arc::new(config)); let listener = TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], port))).await?; let state = Arc::new(RwLock::new(State::default())); println!("mTLS server on :{} (endpoints: /, /status, /register)", port); println!("Test: curl -k --tlsv1.2 --cert {} --key {} https://127.0.0.1:{}", CLIENT_CERT_PATH, CLIENT_KEY_PATH, port); loop { let Ok((stream, peer)) = listener.accept().await else { continue }; let acceptor = acceptor.clone(); let state = state.clone(); tokio::spawn(async move { if let Ok(tls) = acceptor.accept(stream).await { handle_connection(tls, peer, state).await; } }); } }