From 1951b063d7ec6d6e8db8a0b5074c73f887749208 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 29 Dec 2025 22:18:04 +0800 Subject: initial commit --- packet-detector/Cargo.toml | 39 +++++ packet-detector/src/bin/tls_client.rs | 51 ++++++ packet-detector/src/bin/tls_server.rs | 295 ++++++++++++++++++++++++++++++++++ packet-detector/src/lib.rs | 6 + packet-detector/src/main.rs | 150 +++++++++++++++++ packet-detector/src/tls_util.rs | 73 +++++++++ packet-detector/src/validator.rs | 64 ++++++++ 7 files changed, 678 insertions(+) create mode 100644 packet-detector/Cargo.toml create mode 100644 packet-detector/src/bin/tls_client.rs create mode 100644 packet-detector/src/bin/tls_server.rs create mode 100644 packet-detector/src/lib.rs create mode 100644 packet-detector/src/main.rs create mode 100644 packet-detector/src/tls_util.rs create mode 100644 packet-detector/src/validator.rs (limited to 'packet-detector') diff --git a/packet-detector/Cargo.toml b/packet-detector/Cargo.toml new file mode 100644 index 0000000..a6c699f --- /dev/null +++ b/packet-detector/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "packet-detector" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "packet-detector" +path = "src/main.rs" + +[[bin]] +name = "tls_server" +path = "src/bin/tls_server.rs" + +[[bin]] +name = "tls_client" +path = "src/bin/tls_client.rs" + +[dependencies] +anyhow = "1" +aya = { workspace = true } +chrono = { version = "0.4", features = ["serde"] } +env_logger = "0.11" +futures = "0.3" +hex = "0.4" +k8s-openapi = { version = "0.26", features = ["v1_31"] } +kube = { version = "2.0", features = ["runtime", "client", "derive", "rustls-tls"] } +log = "0.4" +rcgen = { version = "0.14", features = ["pem", "x509-parser"] } +rustls = { version = "0.23", features = ["tls12", "ring"] } +rustls-pemfile = "2.2" +rustls-webpki = { version = "0.102", features = ["ring"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +sha2 = "0.10" +tls-parser = "0.12" +tokio = { version = "1", features = ["full"] } +tokio-rustls = "0.26" +webpki-roots = "1" +x509-parser = "0.18" diff --git a/packet-detector/src/bin/tls_client.rs b/packet-detector/src/bin/tls_client.rs new file mode 100644 index 0000000..6098172 --- /dev/null +++ b/packet-detector/src/bin/tls_client.rs @@ -0,0 +1,51 @@ +//! mTLS test client + +use std::sync::Arc; +use packet_detector::tls_util::LoggingVerifier; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use rustls::version::TLS12; +use rustls::ClientConfig; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio_rustls::TlsConnector; + +const CERT: &str = "client_cert.pem"; +const KEY: &str = "client_key.pem"; + +fn load_creds() -> Result<(Vec>, PrivateKeyDer<'static>), Box> { + let cert_pem = std::fs::read_to_string(CERT)?; + let key_pem = std::fs::read_to_string(KEY)?; + let certs = rustls_pemfile::certs(&mut cert_pem.as_bytes()).collect::, _>>()?; + let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())?.ok_or("No key")?; + Ok((certs, key)) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + rustls::crypto::ring::default_provider().install_default().ok(); + + let host = std::env::args().nth(1).unwrap_or_else(|| "127.0.0.1".into()); + let port: u16 = std::env::args().nth(2).and_then(|s| s.parse().ok()).unwrap_or(8443); + + let (certs, key) = load_creds()?; + println!("Connecting to {}:{} with client cert", host, port); + + let config = ClientConfig::builder_with_protocol_versions(&[&TLS12]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(LoggingVerifier)) + .with_client_auth_cert(certs, key)?; + + let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; + let mut tls = TlsConnector::from(Arc::new(config)) + .connect(ServerName::try_from(host.clone())?, stream).await?; + println!("TLS handshake complete!"); + + let req = format!("GET / HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", host); + tls.write_all(req.as_bytes()).await?; + + let mut resp = Vec::new(); + tls.read_to_end(&mut resp).await?; + println!("\n{}", String::from_utf8_lossy(&resp)); + + Ok(()) +} diff --git a/packet-detector/src/bin/tls_server.rs b/packet-detector/src/bin/tls_server.rs new file mode 100644 index 0000000..7f68062 --- /dev/null +++ b/packet-detector/src/bin/tls_server.rs @@ -0,0 +1,295 @@ +//! TLS 1.2 Test Server with Client Tracking (mTLS Enabled) +//! +//! A mutual TLS (mTLS) HTTPS server for testing the eBPF TLS certificate validator. +//! Requires client certificates signed by the CA for authentication. +//! Tracks connected clients and displays their status. + +use std::collections::HashMap; +use std::env; +use std::fs; +use std::net::SocketAddr; +use std::path::Path; +use std::sync::Arc; + +use packet_detector::tls_util::{dn, parse_pem}; +use rcgen::{BasicConstraints, CertificateParams, IsCa, Issuer, KeyPair, KeyUsagePurpose, SanType}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::server::{ServerConfig, WebPkiClientVerifier}; +use rustls::version::TLS12; +use rustls::RootCertStore; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::RwLock; +use tokio_rustls::TlsAcceptor; + +const DEFAULT_PORT: u16 = 8443; +const CA_CERT_PATH: &str = "ca_cert.pem"; +const CA_KEY_PATH: &str = "ca_key.pem"; +const SERVER_CERT_PATH: &str = "server_cert.pem"; +const SERVER_KEY_PATH: &str = "server_key.pem"; +const CLIENT_CERT_PATH: &str = "client_cert.pem"; +const CLIENT_KEY_PATH: &str = "client_key.pem"; + +#[derive(Clone, Debug, serde::Serialize)] +struct Client { + ip: String, + connected_at: chrono::DateTime, + #[serde(skip)] + last_seen: chrono::DateTime, + requests: u64, +} + +#[derive(Default)] +struct State { + clients: HashMap, + connections: u64, + requests: u64, +} + +/// Generate a CA certificate for signing server and client certificates +fn generate_ca_certificate() -> Result<(String, String, KeyPair), Box> { + println!("Generating CA certificate..."); + + let mut params = CertificateParams::default(); + params.distinguished_name = dn("eBPF Test CA", "Zero Trust Network"); + params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + params.key_usages = vec![ + KeyUsagePurpose::KeyCertSign, + KeyUsagePurpose::CrlSign, + ]; + + let key_pair = KeyPair::generate()?; + let cert = params.self_signed(&key_pair)?; + + Ok((cert.pem(), key_pair.serialize_pem(), key_pair)) +} + +/// Generate a server certificate signed by the CA +fn generate_server_certificate( + ca_cert_pem: &str, + ca_key_pem: &str, +) -> Result<(String, String), Box> { + println!("Generating server certificate signed by CA..."); + + let ca_key = KeyPair::from_pem(ca_key_pem)?; + let issuer = Issuer::from_ca_cert_pem(ca_cert_pem, ca_key)?; + + let mut params = CertificateParams::default(); + params.distinguished_name = dn("localhost", "eBPF Test Server"); + params.subject_alt_names = vec![ + SanType::DnsName("localhost".try_into()?), + SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))), + ]; + params.key_usages = vec![ + KeyUsagePurpose::DigitalSignature, + KeyUsagePurpose::KeyEncipherment, + ]; + + let key_pair = KeyPair::generate()?; + let cert = params.signed_by(&key_pair, &issuer)?; + + Ok((cert.pem(), key_pair.serialize_pem())) +} + +/// Generate a client certificate signed by the CA +fn generate_client_certificate( + client_name: &str, + ca_cert_pem: &str, + ca_key_pem: &str, +) -> Result<(String, String), Box> { + println!("Generating client certificate for: {}", client_name); + + let ca_key = KeyPair::from_pem(ca_key_pem)?; + let issuer = Issuer::from_ca_cert_pem(ca_cert_pem, ca_key)?; + + let mut params = CertificateParams::default(); + params.distinguished_name = dn(client_name, "eBPF Test Client"); + params.key_usages = vec![ + KeyUsagePurpose::DigitalSignature, + ]; + + let key_pair = KeyPair::generate()?; + let cert = params.signed_by(&key_pair, &issuer)?; + + Ok((cert.pem(), key_pair.serialize_pem())) +} + +/// PKI setup result +struct PkiSetup { + ca_cert_pem: String, + server_cert: Vec>, + server_key: PrivateKeyDer<'static>, +} + +/// Load or generate the full PKI (CA, server cert, client cert) +fn setup_pki() -> Result> { + let ca_cert_pem: String; + let ca_key_pem: String; + + // Load or generate CA + if Path::new(CA_CERT_PATH).exists() && Path::new(CA_KEY_PATH).exists() { + println!("Loading existing CA from {} and {}", CA_CERT_PATH, CA_KEY_PATH); + ca_cert_pem = fs::read_to_string(CA_CERT_PATH)?; + ca_key_pem = fs::read_to_string(CA_KEY_PATH)?; + } else { + println!("Generating new PKI infrastructure..."); + let (cert, key, _) = generate_ca_certificate()?; + fs::write(CA_CERT_PATH, &cert)?; + fs::write(CA_KEY_PATH, &key)?; + println!("Saved CA to {} and {}", CA_CERT_PATH, CA_KEY_PATH); + ca_cert_pem = cert; + ca_key_pem = key; + } + + // Load or generate server certificate + let server_cert_pem: String; + let server_key_pem: String; + + if Path::new(SERVER_CERT_PATH).exists() && Path::new(SERVER_KEY_PATH).exists() { + println!("Loading existing server certificate..."); + server_cert_pem = fs::read_to_string(SERVER_CERT_PATH)?; + server_key_pem = fs::read_to_string(SERVER_KEY_PATH)?; + } else { + let (cert, key) = generate_server_certificate(&ca_cert_pem, &ca_key_pem)?; + fs::write(SERVER_CERT_PATH, &cert)?; + fs::write(SERVER_KEY_PATH, &key)?; + println!("Saved server cert to {} and {}", SERVER_CERT_PATH, SERVER_KEY_PATH); + server_cert_pem = cert; + server_key_pem = key; + } + + // Load or generate client certificate + if !Path::new(CLIENT_CERT_PATH).exists() || !Path::new(CLIENT_KEY_PATH).exists() { + let (cert, key) = generate_client_certificate("test-client", &ca_cert_pem, &ca_key_pem)?; + fs::write(CLIENT_CERT_PATH, &cert)?; + fs::write(CLIENT_KEY_PATH, &key)?; + println!("Saved client cert to {} and {}", CLIENT_CERT_PATH, CLIENT_KEY_PATH); + } + + // Parse certificates + let server_certs = parse_pem(&server_cert_pem)?; + + let server_key = rustls_pemfile::private_key(&mut server_key_pem.as_bytes())? + .ok_or("No server private key found")?; + + Ok(PkiSetup { + ca_cert_pem, + server_cert: server_certs, + server_key, + }) +} + +fn parse_request(data: &[u8]) -> (&str, Option) { + let req = std::str::from_utf8(data).unwrap_or(""); + let first_line = req.lines().next().unwrap_or(""); + + // Check X-Client-Name header + let client = req.lines() + .find(|l| l.to_lowercase().starts_with("x-client-name:")) + .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string()) + .or_else(|| { + // Check ?client= query param + first_line.find("client=").map(|i| { + let rest = &first_line[i + 7..]; + rest[..rest.find(|c| c == '&' || c == ' ').unwrap_or(rest.len())].to_string() + }) + }); + + (first_line, client) +} + +async fn handle_connection( + mut stream: tokio_rustls::server::TlsStream, + peer: SocketAddr, + state: Arc>, +) { + state.write().await.connections += 1; + + let mut buf = vec![0u8; 4096]; + let n = match stream.read(&mut buf).await { + Ok(0) => return, + Ok(n) => n, + Err(_) => return, + }; + + let (path, client_name) = parse_request(&buf[..n]); + + // Track client if name provided + if let Some(name) = &client_name { + let mut s = state.write().await; + s.requests += 1; + let now = chrono::Utc::now(); + s.clients.entry(name.clone()) + .and_modify(|c| { c.last_seen = now; c.requests += 1; }) + .or_insert(Client { + ip: peer.ip().to_string(), + connected_at: now, + last_seen: now, + requests: 1, + }); + } + + // Route request + let (ctype, body) = if path.contains("/status") || path.contains("/clients") { + let s = state.read().await; + ("application/json", serde_json::json!({ + "clients": s.clients.len(), + "connections": s.connections, + "requests": s.requests, + "list": &s.clients + }).to_string()) + } else if path.contains("/register") { + ("application/json", serde_json::json!({ + "status": "ok", + "client": client_name.as_deref().unwrap_or("unknown") + }).to_string()) + } else { + ("text/plain", format!("TLS Server OK - {} clients", state.read().await.clients.len())) + }; + + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + ctype, body.len(), body + ); + let _ = stream.write_all(resp.as_bytes()).await; + let _ = stream.shutdown().await; +} + + +#[tokio::main] +async fn main() -> Result<(), Box> { + rustls::crypto::ring::default_provider().install_default().ok(); + + let port: u16 = env::args().nth(1).and_then(|s| s.parse().ok()).unwrap_or(DEFAULT_PORT); + + let pki = setup_pki()?; + + // Build root store with CA + let mut root_store = RootCertStore::empty(); + for cert in parse_pem(&pki.ca_cert_pem)? { + root_store.add(cert)?; + } + + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + let config = ServerConfig::builder_with_protocol_versions(&[&TLS12]) + .with_client_cert_verifier(verifier) + .with_single_cert(pki.server_cert, pki.server_key)?; + + let acceptor = TlsAcceptor::from(Arc::new(config)); + let listener = TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], port))).await?; + let state = Arc::new(RwLock::new(State::default())); + + println!("mTLS server on :{} (endpoints: /, /status, /register)", port); + println!("Test: curl -k --tlsv1.2 --cert {} --key {} https://127.0.0.1:{}", CLIENT_CERT_PATH, CLIENT_KEY_PATH, port); + + loop { + let Ok((stream, peer)) = listener.accept().await else { continue }; + let acceptor = acceptor.clone(); + let state = state.clone(); + tokio::spawn(async move { + if let Ok(tls) = acceptor.accept(stream).await { + handle_connection(tls, peer, state).await; + } + }); + } +} diff --git a/packet-detector/src/lib.rs b/packet-detector/src/lib.rs new file mode 100644 index 0000000..a3a8ac1 --- /dev/null +++ b/packet-detector/src/lib.rs @@ -0,0 +1,6 @@ +//! Packet Detector Library +//! +//! Shared utilities for TLS certificate validation. + +pub mod tls_util; +pub mod validator; diff --git a/packet-detector/src/main.rs b/packet-detector/src/main.rs new file mode 100644 index 0000000..69cccec --- /dev/null +++ b/packet-detector/src/main.rs @@ -0,0 +1,150 @@ +//! TLS Certificate Validator / UDP Magic Detector - eBPF-based + +use std::collections::HashSet; +use std::mem::size_of; +use std::net::Ipv4Addr; + +use anyhow::{Context, Result}; +use aya::maps::{HashMap as AyaHashMap, RingBuf}; +use aya::programs::{Xdp, XdpFlags}; +use aya::{include_bytes_aligned, Bpf}; +use log::{info, warn}; +use tls_parser::{parse_tls_plaintext, TlsMessage, TlsMessageHandshake}; +use tokio::signal; + +use packet_detector::validator::CertValidator; + +// this has to be the exact same as the struct in kernelspace +#[repr(C)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +struct ConnKey { + port_lo: u16, + port_hi: u16, +} + +unsafe impl aya::Pod for ConnKey {} + +fn make_conn_key(src_port: u16, dst_port: u16) -> ConnKey { + if src_port < dst_port { + ConnKey { port_lo: src_port, port_hi: dst_port } + } else { + ConnKey { port_lo: dst_port, port_hi: src_port } + } +} + +// this has to be the exact same as the struct in kernelspace +#[repr(C)] +#[derive(Clone, Copy)] +struct Event { + src_ip: u32, + dst_ip: u32, + src_port: u16, + dst_port: u16, + tls_len: u16, + _pad: u16, +} + +unsafe impl aya::Pod for Event {} + +const EVENT_SIZE: usize = size_of::(); + +fn ip(n: u32) -> Ipv4Addr { + Ipv4Addr::from(n.to_be_bytes()) +} + +fn extract_certs(tls_data: &[u8]) -> Option>> { + if tls_data.len() < 6 || tls_data[0] != 0x16 || tls_data[5] != 0x0B { return None; } + let (_, rec) = parse_tls_plaintext(tls_data).ok()?; + for msg in &rec.msg { + if let TlsMessage::Handshake(TlsMessageHandshake::Certificate(c)) = msg { + return Some(c.cert_chain.iter().map(|x| x.data.to_vec()).collect()); + } + } + None +} + +enum Decision { + Allow(ConnKey), + Block(ConnKey), + Skip, +} + +fn handle_event(data: &[u8], validator: Option<&CertValidator>) -> Decision { + if data.len() < EVENT_SIZE { return Decision::Skip; } + let ev: Event = unsafe { std::ptr::read(data.as_ptr() as *const _) }; + let addr = format!("{}:{} -> {}:{}", ip(ev.src_ip), ev.src_port, ip(ev.dst_ip), ev.dst_port); + let conn_key = make_conn_key(ev.src_port, ev.dst_port); + + if ev.tls_len == 0 { + info!("UDP magic from {}", addr); + return Decision::Allow(conn_key); + } + + let Some(v) = validator else { return Decision::Skip }; + let end = EVENT_SIZE + ev.tls_len as usize; + if end > data.len() { return Decision::Skip; } + let Some(certs) = extract_certs(&data[EVENT_SIZE..end]) else { return Decision::Skip }; + let result = v.validate(&certs); + info!("{}: {}", addr, result.subject); + + if result.valid { + info!("ALLOW conn {}:{} (signed by {})", conn_key.port_lo, conn_key.port_hi, result.issuer); + Decision::Allow(conn_key) + } else { + warn!("BLOCK conn {}:{} - {}", conn_key.port_lo, conn_key.port_hi, result.error.unwrap_or_default()); + Decision::Block(conn_key) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + rustls::crypto::ring::default_provider().install_default().ok(); + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + + let args: Vec = std::env::args().collect(); + if args.len() < 2 { + eprintln!("Usage: {} [ca-cert.pem]", args[0]); + std::process::exit(1); + } + + let iface = &args[1]; + let validator = args.get(2).map(|p| CertValidator::with_ca_file(p)).transpose()?; + info!("Mode: {}", if validator.is_some() { "TLS cert validation" } else { "UDP magic detection" }); + + let mut bpf = Bpf::load(include_bytes_aligned!("../../target/bpfel-unknown-none/release/packet-detector"))?; + let program: &mut Xdp = bpf.program_mut("packet_detector").unwrap().try_into()?; + program.load()?; + program.attach(iface, XdpFlags::default()).context("XDP attach failed")?; + info!("XDP attached to {}", iface); + + let mut allowed: AyaHashMap<_, ConnKey, u8> = AyaHashMap::try_from(bpf.take_map("ALLOWED_CONNS").unwrap())?; + let mut blocked: AyaHashMap<_, ConnKey, u8> = AyaHashMap::try_from(bpf.take_map("BLOCKED_CONNS").unwrap())?; + let mut ring: RingBuf<_> = RingBuf::try_from(bpf.take_map("TLS_EVENTS").unwrap())?; + let mut allowed_count = 0u32; + let mut blocked_count = 0u32; + + println!("\nRunning on {} - Ctrl+C to stop\n", iface); + + loop { + tokio::select! { + _ = signal::ctrl_c() => break, + _ = tokio::time::sleep(tokio::time::Duration::from_millis(10)) => { + while let Some(item) = ring.next() { + match handle_event(item.as_ref(), validator.as_ref()) { + Decision::Allow(key) => { + allowed.insert(key, 1, 0)?; + allowed_count += 1; + } + Decision::Block(key) => { + blocked.insert(key, 1, 0)?; + blocked_count += 1; + } + Decision::Skip => {} + } + } + } + } + } + println!("\nAllowed: {}, Blocked: {}", allowed_count, blocked_count); + Ok(()) +} diff --git a/packet-detector/src/tls_util.rs b/packet-detector/src/tls_util.rs new file mode 100644 index 0000000..456991b --- /dev/null +++ b/packet-detector/src/tls_util.rs @@ -0,0 +1,73 @@ +use rcgen::{DistinguishedName, DnType}; +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; +use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; +use rustls::{DigitallySignedStruct, Error, SignatureScheme}; +use sha2::{Digest, Sha256}; +use x509_parser::prelude::*; + +/// SHA256 fingerprint (truncated to 16 bytes by default) +pub fn fingerprint(cert: &CertificateDer<'_>, full: bool) -> String { + let hash = Sha256::digest(cert.as_ref()); + if full { hex::encode(hash) } else { hex::encode(&hash[..16]) } +} + +pub fn dn(cn: &str, org: &str) -> DistinguishedName { + let mut d = DistinguishedName::new(); + d.push(DnType::CommonName, cn); + d.push(DnType::OrganizationName, org); + d.push(DnType::CountryName, "US"); + d +} + +/// Parse certs from PEM +pub fn parse_pem(pem: &str) -> Result>, std::io::Error> { + rustls_pemfile::certs(&mut pem.as_bytes()).collect() +} + +fn schemes() -> Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() +} + +/// Macro to implement the boilerplate verifier methods +macro_rules! impl_verifier_base { + () => { + fn verify_tls12_signature(&self, _: &[u8], _: &CertificateDer<'_>, _: &DigitallySignedStruct) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + fn verify_tls13_signature(&self, _: &[u8], _: &CertificateDer<'_>, _: &DigitallySignedStruct) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + fn supported_verify_schemes(&self) -> Vec { + schemes() + } + }; +} + +/// Accepts all certificates, logs info +#[derive(Debug)] +pub struct LoggingVerifier; + +impl ServerCertVerifier for LoggingVerifier { + fn verify_server_cert(&self, cert: &CertificateDer<'_>, intermediates: &[CertificateDer<'_>], _: &ServerName<'_>, _: &[u8], _: UnixTime) -> Result { + println!("\n=== Server Certificate ==="); + match X509Certificate::from_der(cert.as_ref()) { + Ok((_, x)) => { + println!("Subject: {}", x.subject()); + println!("Issuer: {}", x.issuer()); + println!("SHA256: {}", fingerprint(cert, true)); + if x.subject() == x.issuer() { println!("Type: Self-Signed"); } + } + Err(e) => println!("Parse failed: {}", e), + } + for (i, c) in intermediates.iter().enumerate() { + if let Ok((_, x)) = X509Certificate::from_der(c.as_ref()) { + println!("Intermediate #{}: {}", i + 1, x.subject()); + } + } + println!("Chain length: {}\n", 1 + intermediates.len()); + Ok(ServerCertVerified::assertion()) + } + impl_verifier_base!(); +} diff --git a/packet-detector/src/validator.rs b/packet-detector/src/validator.rs new file mode 100644 index 0000000..92e64d7 --- /dev/null +++ b/packet-detector/src/validator.rs @@ -0,0 +1,64 @@ +//! Certificate chain validation - signature only + +use anyhow::{anyhow, Result}; +use rustls::pki_types::CertificateDer; +use x509_parser::prelude::*; + +pub struct ValidationResult { + pub valid: bool, + pub subject: String, + pub issuer: String, + pub error: Option, +} + +impl ValidationResult { + fn fail(subject: String, issuer: String, err: impl ToString) -> Self { + Self { valid: false, subject, issuer, error: Some(err.to_string()) } + } +} + +pub struct CertValidator { + ca_der: Vec, +} + +impl CertValidator { + pub fn with_ca_file(path: &str) -> Result { + let pem = std::fs::read_to_string(path)?; + let der = rustls_pemfile::certs(&mut pem.as_bytes()) + .next() + .ok_or_else(|| anyhow!("No cert in PEM"))??; + Ok(Self { ca_der: der.to_vec() }) + } + + pub fn validate(&self, chain: &[Vec]) -> ValidationResult { + let Some(ee_der) = chain.first() else { + return ValidationResult::fail(String::new(), String::new(), "Empty chain"); + }; + + let (subject, issuer) = match X509Certificate::from_der(ee_der) { + Ok((_, c)) => (c.subject().to_string(), c.issuer().to_string()), + Err(e) => return ValidationResult::fail(String::new(), String::new(), format!("{e:?}")), + }; + + let ca = CertificateDer::from(self.ca_der.clone()); + let anchor = match webpki::anchor_from_trusted_cert(&ca) { + Ok(a) => a, + Err(e) => return ValidationResult::fail(subject, issuer, format!("CA: {e:?}")), + }; + + let cert = CertificateDer::from(ee_der.clone()); + let ee = match webpki::EndEntityCert::try_from(&cert) { + Ok(c) => c, + Err(e) => return ValidationResult::fail(subject, issuer, format!("{e:?}")), + }; + + let intermediates: Vec<_> = chain[1..].iter().map(|c| CertificateDer::from(c.clone())).collect(); + let algos = webpki::ALL_VERIFICATION_ALGS; + let time = webpki::types::UnixTime::since_unix_epoch(std::time::Duration::from_secs(4102444800)); // 2100 + + match ee.verify_for_usage(algos, &[anchor], &intermediates, time, webpki::KeyUsage::client_auth(), None, None) { + Ok(_) => ValidationResult { valid: true, subject, issuer, error: None }, + Err(e) => ValidationResult::fail(subject, issuer, format!("{e:?}")), + } + } +} -- cgit v1.2.3