rename: lumen → punktfunk, everywhere
ci / rust (push) Has been cancelled

Full project rename, decided 2026-06-10:
- Crates/binaries: punktfunk-core / punktfunk-host / punktfunk-client-rs.
- C ABI: punktfunk_* symbols, Punktfunk* types, include/punktfunk_core.h,
  PUNKTFUNK_FEATURE_QUIC guard (header regenerated; cbindgen renames updated, incl.
  PUNKTFUNK_BTN_*/PUNKTFUNK_AXIS_* wire constants).
- Protocol: punktfunk/1 — control-plane magic LMN1 → PKF1, nonce salt lmn1 → pkf1.
  WIRE BREAK: clients must be rebuilt from this revision.
- Env knobs: PUNKTFUNK_VIDEO_SOURCE / PUNKTFUNK_COMPOSITOR / PUNKTFUNK_ZEROCOPY / ….
- Host config dir: ~/.config/punktfunk (the box's dir was migrated in place — the
  persistent identity is unchanged, pinned fingerprints stay valid).
- Swift package: PunktfunkKit + PunktfunkCore.xcframework + PunktfunkConnection
  (Sources/PunktfunkClient app + tests renamed with it); build-xcframework.sh updated.
- scripts/: 60-punktfunk.rules, punktfunk-host.service; OpenAPI doc regenerated.

Also: scripts/headless/run-headless-kde.sh — full headless Plasma bringup. Root cause of
"desktop but no apps/settings" over the stream: plasmashell launched without
XDG_MENU_PREFIX=plasma-, so the launcher resolved a nonexistent applications.menu and
rendered an empty menu. The script sets the complete KDE session env (menu prefix,
KDE_FULL_SESSION, session version) and rebuilds ksycoca before starting plasmashell.

Gate: 97/97 tests, clippy -D warnings (both feature sets), fmt, C-ABI harness PASS,
zero lumen references left outside .git.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-10 13:11:59 +00:00
parent b8b23c8fb2
commit bfd64ce871
119 changed files with 1245 additions and 1185 deletions
+50
View File
@@ -0,0 +1,50 @@
[package]
name = "punktfunk-core"
description = "punktfunk shared protocol/transport/FEC core, exposed over a stable C ABI"
version.workspace = true
edition.workspace = true
rust-version.workspace = true
license.workspace = true
authors.workspace = true
repository.workspace = true
[lib]
name = "punktfunk_core"
# `lib` — so punktfunk-host / punktfunk-client-rs / tools link it as a normal Rust crate.
# `staticlib` — `libpunktfunk_core.a` for the C test harness and static embedding.
# `cdylib` — `libpunktfunk_core.{so,dylib}` for Swift/Kotlin clients via the C ABI.
crate-type = ["lib", "cdylib", "staticlib"]
[features]
default = []
# Control-plane QUIC (pairing, config, reverse audio). tokio is permitted ONLY here,
# never on the per-frame hot path. Off by default so the core stays runtime-free.
quic = ["dep:quinn", "dep:tokio", "dep:rustls", "dep:rcgen", "dep:rustls-pki-types", "dep:sha2"]
[dependencies]
reed-solomon-simd = "3.1" # GF(2^16) Leopard-RS, SIMD, O(n log n) — the wall-breaker (P2)
# Vendored fork of fec-rs: GF(2^8) classic RS with the *Cauchy* generator matrix
# (M[j][i] = inv[(m+i)^j]) — byte-identical to the `nanors` library Moonlight uses, so our
# parity is decodable by a stock Moonlight client. (reed-solomon-erasure is Vandermonde and is
# NOT interoperable.) See vendor/fec-rs/LICENSE (BSD-2-Clause).
fec-rs = { path = "vendor/fec-rs" }
aes-gcm = "0.10" # AES-128-GCM session crypto, matches GameStream
zerocopy = { version = "0.8", features = ["derive"] }
bytes = "1"
thiserror = "2"
tracing = { version = "0.1", default-features = false, features = ["std"] }
rand = "0.9"
zeroize = "1"
quinn = { version = "0.11", optional = true }
rustls = { version = "0.23", optional = true, default-features = false, features = ["ring", "std"] }
rcgen = { version = "0.13", optional = true, default-features = false, features = ["aws_lc_rs"] }
rustls-pki-types = { version = "1", optional = true }
sha2 = { version = "0.10", optional = true }
tokio = { version = "1", optional = true, features = ["rt-multi-thread", "net", "sync", "macros"] }
[dev-dependencies]
proptest = "1"
[build-dependencies]
cbindgen = "0.29"
+34
View File
@@ -0,0 +1,34 @@
//! Generate the C header (`include/punktfunk_core.h`) from the `extern "C"` surface.
//!
//! cbindgen failure is a warning, not a hard error, so the crate still builds in minimal
//! environments (e.g. a CI image without the full toolchain); the header is checked in.
use std::env;
use std::path::PathBuf;
fn main() {
println!("cargo:rerun-if-changed=src/abi.rs");
println!("cargo:rerun-if-changed=src/config.rs");
println!("cargo:rerun-if-changed=src/input.rs");
println!("cargo:rerun-if-changed=src/client.rs");
println!("cargo:rerun-if-changed=src/error.rs");
println!("cargo:rerun-if-changed=cbindgen.toml");
let crate_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR");
// Workspace-level include/ dir: crates/punktfunk-core/ -> ../../include/
let out = PathBuf::from(&crate_dir)
.join("..")
.join("..")
.join("include")
.join("punktfunk_core.h");
match cbindgen::generate(&crate_dir) {
Ok(bindings) => {
bindings.write_to_file(&out);
println!("cargo:warning=punktfunk-core: wrote {}", out.display());
}
Err(e) => {
println!("cargo:warning=punktfunk-core: cbindgen failed ({e}); header not regenerated");
}
}
}
+56
View File
@@ -0,0 +1,56 @@
language = "C"
pragma_once = true
include_guard = "PUNKTFUNK_CORE_H"
autogen_warning = "/* Generated by cbindgen from punktfunk-core. Do not edit by hand. */"
header = "/* punktfunk-core C ABI — see crates/punktfunk-core/src/abi.rs */"
style = "type"
cpp_compat = true
tab_width = 4
documentation = true
documentation_style = "c99"
[parse]
parse_deps = false
[export.rename]
"InputEvent" = "PunktfunkInputEvent"
"InputKind" = "PunktfunkInputKind"
# Gamepad wire constants: bare BTN_* names collide with <linux/input-event-codes.h> (at
# DIFFERENT values — last definition silently wins); prefix everything we export.
"BTN_DPAD_UP" = "PUNKTFUNK_BTN_DPAD_UP"
"BTN_DPAD_DOWN" = "PUNKTFUNK_BTN_DPAD_DOWN"
"BTN_DPAD_LEFT" = "PUNKTFUNK_BTN_DPAD_LEFT"
"BTN_DPAD_RIGHT" = "PUNKTFUNK_BTN_DPAD_RIGHT"
"BTN_START" = "PUNKTFUNK_BTN_START"
"BTN_BACK" = "PUNKTFUNK_BTN_BACK"
"BTN_LS_CLICK" = "PUNKTFUNK_BTN_LS_CLICK"
"BTN_RS_CLICK" = "PUNKTFUNK_BTN_RS_CLICK"
"BTN_LB" = "PUNKTFUNK_BTN_LB"
"BTN_RB" = "PUNKTFUNK_BTN_RB"
"BTN_GUIDE" = "PUNKTFUNK_BTN_GUIDE"
"BTN_A" = "PUNKTFUNK_BTN_A"
"BTN_B" = "PUNKTFUNK_BTN_B"
"BTN_X" = "PUNKTFUNK_BTN_X"
"BTN_Y" = "PUNKTFUNK_BTN_Y"
"AXIS_LS_X" = "PUNKTFUNK_AXIS_LS_X"
"AXIS_LS_Y" = "PUNKTFUNK_AXIS_LS_Y"
"AXIS_RS_X" = "PUNKTFUNK_AXIS_RS_X"
"AXIS_RS_Y" = "PUNKTFUNK_AXIS_RS_Y"
"AXIS_LT" = "PUNKTFUNK_AXIS_LT"
"AXIS_RT" = "PUNKTFUNK_AXIS_RT"
"AUDIO_MAGIC" = "PUNKTFUNK_AUDIO_MAGIC"
"RUMBLE_MAGIC" = "PUNKTFUNK_RUMBLE_MAGIC"
# QualifiedScreamingSnakeCase already qualifies each variant with the enum name
# (PunktfunkStatus::Ok -> PUNKTFUNK_STATUS_OK); do NOT also set prefix_with_name or it doubles.
[enum]
rename_variants = "QualifiedScreamingSnakeCase"
[fn]
sort_by = "None"
[struct]
derive_eq = false
[defines]
"feature = quic" = "PUNKTFUNK_FEATURE_QUIC"
+755
View File
@@ -0,0 +1,755 @@
//! The stable `extern "C"` surface. `cbindgen` turns this module into
//! `include/punktfunk_core.h` (see `build.rs`).
//!
//! ## Principles (plan §5)
//! - Opaque handles only: C sees `PunktfunkSession*`, never a Rust type's fields.
//! - All cross-boundary structs are `#[repr(C)]`; buffers are pointer + length.
//! - Explicit ownership: every handle from `*_new` / `*_pair` must be passed to
//! [`punktfunk_session_free`]. A [`PunktfunkFrame`]'s `data` is borrowed until the next
//! `poll`/`free` on that session — copy it out before then.
//! - Versioned: [`punktfunk_abi_version`] + `PunktfunkConfig::struct_size` for forward-compat.
//! - Panics never cross the boundary: every entry point is wrapped in `catch_unwind`.
use crate::config::{Config, FecConfig, FecScheme, ProtocolPhase, Role};
use crate::error::PunktfunkStatus;
use crate::input::InputEvent;
use crate::session::Session;
use crate::stats::Stats;
use crate::transport::{loopback_pair, Transport, UdpTransport};
use std::ffi::{c_void, CStr};
use std::os::raw::c_char;
use std::panic::AssertUnwindSafe;
use std::ptr;
/// Opaque session handle. Pointer-only from C.
pub struct PunktfunkSession {
inner: Session,
/// Keeps the most recently polled frame alive so [`PunktfunkFrame::data`] stays valid
/// until the next poll or free.
last_frame: Option<crate::session::Frame>,
input_cb: Option<(PunktfunkInputCb, *mut c_void)>,
}
/// Forward-compatible session configuration. The caller MUST set `struct_size` to
/// `sizeof(PunktfunkConfig)`; the core uses it to detect ABI skew.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct PunktfunkConfig {
pub struct_size: u32,
/// 0 = host, 1 = client.
pub role: u32,
/// 1 = P1 (GameStream-compatible), 2 = P2 (`punktfunk/1`).
pub phase: u32,
/// 0 = GF(2⁸), 1 = GF(2¹⁶).
pub fec_scheme: u32,
pub fec_percent: u32,
pub max_data_per_block: u32,
pub shard_payload: u32,
/// Non-zero enables AES-128-GCM.
pub encrypt: u32,
pub key: [u8; 16],
pub salt: [u8; 4],
/// Test hook for the loopback transport; 0 in production.
pub loopback_drop_period: u32,
/// Largest encoded access unit the receiver will accept (bounds reassembler memory).
pub max_frame_bytes: u64,
}
impl PunktfunkConfig {
fn to_config(self) -> Result<Config, PunktfunkStatus> {
let role = match self.role {
0 => Role::Host,
1 => Role::Client,
_ => return Err(PunktfunkStatus::InvalidArg),
};
let phase = match self.phase {
1 => ProtocolPhase::P1GameStream,
2 => ProtocolPhase::P2Punktfunk,
_ => return Err(PunktfunkStatus::InvalidArg),
};
// Range-check before narrowing: a `300` fec_percent or `65600` block size must be
// rejected, not silently truncated to a valid-looking value.
let scheme = u8::try_from(self.fec_scheme)
.ok()
.and_then(FecScheme::from_u8)
.ok_or(PunktfunkStatus::InvalidArg)?;
let fec_percent =
u8::try_from(self.fec_percent).map_err(|_| PunktfunkStatus::InvalidArg)?;
let max_data_per_block =
u16::try_from(self.max_data_per_block).map_err(|_| PunktfunkStatus::InvalidArg)?;
let cfg = Config {
role,
phase,
fec: FecConfig {
scheme,
fec_percent,
max_data_per_block,
},
shard_payload: self.shard_payload as usize,
max_frame_bytes: self.max_frame_bytes as usize,
encrypt: self.encrypt != 0,
key: self.key,
salt: self.salt,
loopback_drop_period: self.loopback_drop_period,
};
cfg.validate().map_err(|e| e.status())?;
Ok(cfg)
}
}
/// Read a `PunktfunkConfig` from a caller pointer, enforcing the `struct_size` ABI-skew
/// guard *before* reading the whole struct: a caller compiled against a smaller (older)
/// layout is rejected rather than causing an out-of-bounds read.
///
/// # Safety
/// `cfg` must either be null or point to at least its own declared `struct_size` bytes.
unsafe fn config_from_ptr(cfg: *const PunktfunkConfig) -> Result<Config, PunktfunkStatus> {
if cfg.is_null() {
return Err(PunktfunkStatus::NullPointer);
}
// Read only the 4-byte size prefix first to bound the subsequent full read.
let declared = unsafe { std::ptr::addr_of!((*cfg).struct_size).read_unaligned() } as usize;
if declared < std::mem::size_of::<PunktfunkConfig>() {
return Err(PunktfunkStatus::InvalidArg);
}
unsafe { *cfg }.to_config()
}
/// A reassembled access unit. `data`/`len` borrow session-owned memory valid until the
/// next `punktfunk_client_poll_frame`/`punktfunk_session_free` on the same session.
#[repr(C)]
pub struct PunktfunkFrame {
pub data: *const u8,
pub len: usize,
pub frame_index: u32,
pub pts_ns: u64,
pub flags: u32,
}
/// Snapshot of session counters.
#[repr(C)]
#[derive(Clone, Copy, Default)]
pub struct PunktfunkStats {
pub frames_submitted: u64,
pub frames_completed: u64,
pub frames_dropped: u64,
pub packets_sent: u64,
pub packets_received: u64,
pub packets_dropped: u64,
pub fec_recovered_shards: u64,
pub bytes_sent: u64,
pub bytes_received: u64,
}
impl From<Stats> for PunktfunkStats {
fn from(s: Stats) -> Self {
PunktfunkStats {
frames_submitted: s.frames_submitted,
frames_completed: s.frames_completed,
frames_dropped: s.frames_dropped,
packets_sent: s.packets_sent,
packets_received: s.packets_received,
packets_dropped: s.packets_dropped,
fec_recovered_shards: s.fec_recovered_shards,
bytes_sent: s.bytes_sent,
bytes_received: s.bytes_received,
}
}
}
/// Host-side callback invoked for each input event drained by `punktfunk_host_poll_input`.
pub type PunktfunkInputCb = extern "C" fn(event: *const InputEvent, user: *mut c_void);
#[inline]
fn guard<F: FnOnce() -> PunktfunkStatus>(f: F) -> PunktfunkStatus {
std::panic::catch_unwind(AssertUnwindSafe(f)).unwrap_or(PunktfunkStatus::Panic)
}
fn new_handle(session: Session) -> *mut PunktfunkSession {
Box::into_raw(Box::new(PunktfunkSession {
inner: session,
last_frame: None,
input_cb: None,
}))
}
/// Current ABI version. Mismatch with [`crate::ABI_VERSION`] means incompatible core.
#[no_mangle]
pub extern "C" fn punktfunk_abi_version() -> u32 {
crate::ABI_VERSION
}
/// Create a session over a real UDP transport (`local`/`peer` are `host:port` strings).
/// Returns NULL on error.
///
/// # Safety
/// `cfg`, `local`, `peer` must be valid pointers; the strings must be NUL-terminated.
#[no_mangle]
pub unsafe extern "C" fn punktfunk_session_new(
cfg: *const PunktfunkConfig,
local: *const c_char,
peer: *const c_char,
) -> *mut PunktfunkSession {
let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
if cfg.is_null() || local.is_null() || peer.is_null() {
return ptr::null_mut();
}
let config = match unsafe { config_from_ptr(cfg) } {
Ok(c) => c,
Err(_) => return ptr::null_mut(),
};
let local = match unsafe { CStr::from_ptr(local) }.to_str() {
Ok(s) => s,
Err(_) => return ptr::null_mut(),
};
let peer = match unsafe { CStr::from_ptr(peer) }.to_str() {
Ok(s) => s,
Err(_) => return ptr::null_mut(),
};
let transport: Box<dyn Transport> = match UdpTransport::connect(local, peer) {
Ok(t) => Box::new(t),
Err(_) => return ptr::null_mut(),
};
match Session::new(config, transport) {
Ok(s) => new_handle(s),
Err(_) => ptr::null_mut(),
}
}));
result.unwrap_or(ptr::null_mut())
}
/// Create a connected host+client session pair sharing an in-process loopback
/// transport. Test/dev only — exercises the full FEC + framing path without a network.
///
/// # Safety
/// All four pointers must be valid; the two out-params receive owned handles.
#[no_mangle]
pub unsafe extern "C" fn punktfunk_test_loopback_pair(
host_cfg: *const PunktfunkConfig,
client_cfg: *const PunktfunkConfig,
out_host: *mut *mut PunktfunkSession,
out_client: *mut *mut PunktfunkSession,
) -> PunktfunkStatus {
guard(|| {
if host_cfg.is_null() || client_cfg.is_null() || out_host.is_null() || out_client.is_null()
{
return PunktfunkStatus::NullPointer;
}
let hconf = match unsafe { config_from_ptr(host_cfg) } {
Ok(c) => c,
Err(s) => return s,
};
let cconf = match unsafe { config_from_ptr(client_cfg) } {
Ok(c) => c,
Err(s) => return s,
};
let (ht, ct) = loopback_pair(hconf.loopback_drop_period, cconf.loopback_drop_period);
let hs = match Session::new(hconf, Box::new(ht)) {
Ok(s) => s,
Err(e) => return e.status(),
};
let cs = match Session::new(cconf, Box::new(ct)) {
Ok(s) => s,
Err(e) => return e.status(),
};
unsafe {
*out_host = new_handle(hs);
*out_client = new_handle(cs);
}
PunktfunkStatus::Ok
})
}
/// Free a session handle. Safe to call with NULL.
///
/// # Safety
/// `s` must be a handle from `punktfunk_session_new`/`punktfunk_test_loopback_pair`, freed once.
#[no_mangle]
pub unsafe extern "C" fn punktfunk_session_free(s: *mut PunktfunkSession) {
if !s.is_null() {
drop(unsafe { Box::from_raw(s) });
}
}
/// Host: FEC-protect, packetize, seal and send one encoded access unit.
///
/// # Safety
/// `s` is a valid host handle; `data` points to `len` readable bytes (or `len == 0`).
#[no_mangle]
pub unsafe extern "C" fn punktfunk_host_submit_frame(
s: *mut PunktfunkSession,
data: *const u8,
len: usize,
pts_ns: u64,
flags: u32,
) -> PunktfunkStatus {
guard(|| {
let s = match unsafe { s.as_mut() } {
Some(s) => s,
None => return PunktfunkStatus::NullPointer,
};
if data.is_null() && len != 0 {
return PunktfunkStatus::NullPointer;
}
let slice = if len == 0 {
&[][..]
} else {
unsafe { std::slice::from_raw_parts(data, len) }
};
match s.inner.submit_frame(slice, pts_ns, flags) {
Ok(()) => PunktfunkStatus::Ok,
Err(e) => e.status(),
}
})
}
/// Client: poll for the next reassembled access unit. Returns [`PunktfunkStatus::NoFrame`]
/// when nothing is ready yet. On `Ok`, `*out` borrows session memory until the next poll.
///
/// # Safety
/// `s` is a valid client handle; `out` points to a writable `PunktfunkFrame`.
#[no_mangle]
pub unsafe extern "C" fn punktfunk_client_poll_frame(
s: *mut PunktfunkSession,
out: *mut PunktfunkFrame,
) -> PunktfunkStatus {
guard(|| {
let s = match unsafe { s.as_mut() } {
Some(s) => s,
None => return PunktfunkStatus::NullPointer,
};
if out.is_null() {
return PunktfunkStatus::NullPointer;
}
match s.inner.poll_frame() {
Ok(frame) => {
s.last_frame = Some(frame);
let f = s.last_frame.as_ref().unwrap();
unsafe {
*out = PunktfunkFrame {
data: f.data.as_ptr(),
len: f.data.len(),
frame_index: f.frame_index,
pts_ns: f.pts_ns,
flags: f.flags,
};
}
PunktfunkStatus::Ok
}
Err(e) => e.status(),
}
})
}
/// Client: serialize and send one input event to the host.
///
/// # Safety
/// `s` is a valid client handle; `ev` points to a valid [`InputEvent`].
#[no_mangle]
pub unsafe extern "C" fn punktfunk_send_input(
s: *mut PunktfunkSession,
ev: *const InputEvent,
) -> PunktfunkStatus {
guard(|| {
let s = match unsafe { s.as_mut() } {
Some(s) => s,
None => return PunktfunkStatus::NullPointer,
};
let ev = match unsafe { ev.as_ref() } {
Some(e) => e,
None => return PunktfunkStatus::NullPointer,
};
match s.inner.send_input(ev) {
Ok(()) => PunktfunkStatus::Ok,
Err(e) => e.status(),
}
})
}
/// Register the host-side input callback (pass a NULL fn pointer to clear). The callback
/// fires from within [`punktfunk_host_poll_input`], on the calling thread.
///
/// # Safety
/// `s` is a valid host handle; `user` is passed back verbatim to `cb`.
#[no_mangle]
pub unsafe extern "C" fn punktfunk_set_input_callback(
s: *mut PunktfunkSession,
// Written as an explicit `Option<fn>` (not the `PunktfunkInputCb` alias) so cbindgen
// emits a nullable C function pointer rather than an opaque wrapper struct.
cb: Option<extern "C" fn(event: *const InputEvent, user: *mut c_void)>,
user: *mut c_void,
) -> PunktfunkStatus {
guard(|| {
let s = match unsafe { s.as_mut() } {
Some(s) => s,
None => return PunktfunkStatus::NullPointer,
};
s.input_cb = cb.map(|c| (c, user));
PunktfunkStatus::Ok
})
}
/// Host: drain all pending input events, invoking the registered callback for each.
/// Returns the count dispatched (≥ 0), or a negative [`PunktfunkStatus`] on error.
///
/// # Safety
/// `s` is a valid host handle.
#[no_mangle]
pub unsafe extern "C" fn punktfunk_host_poll_input(s: *mut PunktfunkSession) -> i32 {
let r = std::panic::catch_unwind(AssertUnwindSafe(|| {
let s = match unsafe { s.as_mut() } {
Some(s) => s,
None => return PunktfunkStatus::NullPointer as i32,
};
let cb = s.input_cb;
let mut count = 0i32;
loop {
match s.inner.poll_input() {
Ok(Some(ev)) => {
if let Some((cb, user)) = cb {
cb(&ev as *const InputEvent, user);
}
count += 1;
}
Ok(None) => break,
Err(e) => return e.status() as i32,
}
}
count
}));
r.unwrap_or(PunktfunkStatus::Panic as i32)
}
/// Copy session counters into `*out`.
///
/// # Safety
/// `s` is a valid handle; `out` points to a writable `PunktfunkStats`.
#[no_mangle]
pub unsafe extern "C" fn punktfunk_get_stats(
s: *mut PunktfunkSession,
out: *mut PunktfunkStats,
) -> PunktfunkStatus {
guard(|| {
let s = match unsafe { s.as_ref() } {
Some(s) => s,
None => return PunktfunkStatus::NullPointer,
};
if out.is_null() {
return PunktfunkStatus::NullPointer;
}
let stats = s.inner.stats();
unsafe { *out = PunktfunkStats::from(stats) };
PunktfunkStatus::Ok
})
}
// ---------------------------------------------------------------------------------------------
// punktfunk/1 connection API (`quic` feature) — the embeddable client connector platform clients
// link (SwiftUI/VideoToolbox, Android, …). In the generated header these are guarded by
// `PUNKTFUNK_FEATURE_QUIC`; define it when linking a punktfunk-core built with `--features quic`.
// ---------------------------------------------------------------------------------------------
/// Opaque handle to a live `punktfunk/1` connection (QUIC control plane + UDP data plane, all
/// pumped on internal threads).
///
/// Thread contract: each plane (video `next_au`, audio `next_audio`, rumble `next_rumble`)
/// may be pulled from its own thread, at most one thread per plane. The accessors only
/// take shared references internally (per-plane mutexed borrow slots), so cross-plane
/// concurrency is sound — never two threads on the *same* plane.
#[cfg(feature = "quic")]
pub struct PunktfunkConnection {
inner: crate::client::NativeClient,
/// Backs the pointer returned by the last `punktfunk_connection_next_au` (borrow-until-next-call).
last: std::sync::Mutex<Option<crate::session::Frame>>,
/// Same, for `punktfunk_connection_next_audio` (independent of the video slot).
last_audio: std::sync::Mutex<Option<crate::client::AudioPacket>>,
}
/// Connect to a `punktfunk/1` host and start a session at `width`x`height`@`refresh_hz`.
/// Blocks up to `timeout_ms` for the handshake. Returns NULL on failure.
///
/// Trust: `pin_sha256` (NULL or 32 bytes) is the expected SHA-256 fingerprint of the host's
/// certificate — a mismatching host is rejected. NULL = trust on first use; persist the
/// fingerprint written to `observed_sha256_out` (NULL or 32 bytes, filled on success) and
/// pass it as the pin on every later connect.
///
/// # Safety
/// `host` is a NUL-terminated UTF-8 string (IP or hostname resolvable by the platform);
/// `pin_sha256`/`observed_sha256_out` are each NULL or valid for 32 bytes.
#[cfg(feature = "quic")]
#[no_mangle]
pub unsafe extern "C" fn punktfunk_connect(
host: *const std::os::raw::c_char,
port: u16,
width: u32,
height: u32,
refresh_hz: u32,
pin_sha256: *const u8,
observed_sha256_out: *mut u8,
timeout_ms: u32,
) -> *mut PunktfunkConnection {
let r = std::panic::catch_unwind(AssertUnwindSafe(|| {
if host.is_null() {
return std::ptr::null_mut();
}
let host = match unsafe { std::ffi::CStr::from_ptr(host) }.to_str() {
Ok(s) => s,
Err(_) => return std::ptr::null_mut(),
};
let mode = crate::config::Mode {
width,
height,
refresh_hz,
};
let pin = if pin_sha256.is_null() {
None
} else {
let mut p = [0u8; 32];
p.copy_from_slice(unsafe { std::slice::from_raw_parts(pin_sha256, 32) });
Some(p)
};
match crate::client::NativeClient::connect(
host,
port,
mode,
pin,
std::time::Duration::from_millis(timeout_ms as u64),
) {
Ok(c) => {
if !observed_sha256_out.is_null() {
unsafe {
std::slice::from_raw_parts_mut(observed_sha256_out, 32)
.copy_from_slice(&c.host_fingerprint);
}
}
Box::into_raw(Box::new(PunktfunkConnection {
inner: c,
last: std::sync::Mutex::new(None),
last_audio: std::sync::Mutex::new(None),
}))
}
Err(_) => std::ptr::null_mut(),
}
}));
r.unwrap_or(std::ptr::null_mut())
}
/// Pull the next reassembled access unit, waiting up to `timeout_ms`. Returns
/// [`PunktfunkStatus::NoFrame`] on timeout and [`PunktfunkStatus::Closed`] once the session ended.
/// On `Ok`, `*out` borrows connection memory **until the next `next_au` call** on this
/// handle (the audio/rumble planes do not invalidate it).
///
/// # Safety
/// `c` is a valid connection handle; `out` is writable. At most one thread pulls video —
/// it may run concurrently with one audio-pulling and one rumble-pulling thread.
#[cfg(feature = "quic")]
#[no_mangle]
pub unsafe extern "C" fn punktfunk_connection_next_au(
c: *mut PunktfunkConnection,
out: *mut PunktfunkFrame,
timeout_ms: u32,
) -> PunktfunkStatus {
guard(|| {
// Shared reference only: video and audio threads must never alias a `&mut`.
let c = match unsafe { c.as_ref() } {
Some(c) => c,
None => return PunktfunkStatus::NullPointer,
};
if out.is_null() {
return PunktfunkStatus::NullPointer;
}
match c
.inner
.next_frame(std::time::Duration::from_millis(timeout_ms as u64))
{
Ok(frame) => {
let mut slot = c.last.lock().unwrap();
*slot = Some(frame);
let f = slot.as_ref().unwrap();
unsafe {
*out = PunktfunkFrame {
data: f.data.as_ptr(),
len: f.data.len(),
frame_index: f.frame_index,
pts_ns: f.pts_ns,
flags: f.flags,
};
}
PunktfunkStatus::Ok
}
Err(e) => e.status(),
}
})
}
/// One Opus audio packet pulled off a `punktfunk/1` connection (48 kHz stereo, 5 ms frames).
/// `data` borrows connection memory until the next `punktfunk_connection_next_audio` call.
#[cfg(feature = "quic")]
#[repr(C)]
pub struct PunktfunkAudioPacket {
pub data: *const u8,
pub len: usize,
pub seq: u32,
pub pts_ns: u64,
}
/// Pull the next Opus audio packet, waiting up to `timeout_ms`. Returns
/// [`PunktfunkStatus::NoFrame`] on timeout and [`PunktfunkStatus::Closed`] once the session ended.
/// On `Ok`, `out->data` borrows connection memory **until the next audio call** on this
/// handle (independent of the video slot). Drain from a dedicated audio thread — packets
/// arrive every 5 ms and the internal queue holds 320 ms.
///
/// # Safety
/// `c` is a valid connection handle; `out` is writable. At most one thread pulls audio —
/// it may run concurrently with the video/rumble pullers.
#[cfg(feature = "quic")]
#[no_mangle]
pub unsafe extern "C" fn punktfunk_connection_next_audio(
c: *mut PunktfunkConnection,
out: *mut PunktfunkAudioPacket,
timeout_ms: u32,
) -> PunktfunkStatus {
guard(|| {
let c = match unsafe { c.as_ref() } {
Some(c) => c,
None => return PunktfunkStatus::NullPointer,
};
if out.is_null() {
return PunktfunkStatus::NullPointer;
}
match c
.inner
.next_audio(std::time::Duration::from_millis(timeout_ms as u64))
{
Ok(pkt) => {
let mut slot = c.last_audio.lock().unwrap();
*slot = Some(pkt);
let p = slot.as_ref().unwrap();
unsafe {
*out = PunktfunkAudioPacket {
data: p.data.as_ptr(),
len: p.data.len(),
seq: p.seq,
pts_ns: p.pts_ns,
};
}
PunktfunkStatus::Ok
}
Err(e) => e.status(),
}
})
}
/// Pull the next rumble (force-feedback) update, waiting up to `timeout_ms`. Amplitudes
/// are 0..0xFFFF (`low` = low-frequency motor, `high` = high-frequency), `(0, 0)` = stop.
/// Same timeout/closed semantics as [`punktfunk_connection_next_audio`].
///
/// # Safety
/// `c` is a valid connection handle; out pointers are writable (NULLs are skipped). At
/// most one thread pulls rumble — it may run concurrently with the video/audio pullers.
#[cfg(feature = "quic")]
#[no_mangle]
pub unsafe extern "C" fn punktfunk_connection_next_rumble(
c: *mut PunktfunkConnection,
pad: *mut u16,
low: *mut u16,
high: *mut u16,
timeout_ms: u32,
) -> PunktfunkStatus {
guard(|| {
let c = match unsafe { c.as_ref() } {
Some(c) => c,
None => return PunktfunkStatus::NullPointer,
};
match c
.inner
.next_rumble(std::time::Duration::from_millis(timeout_ms as u64))
{
Ok((p, l, h)) => {
unsafe {
if !pad.is_null() {
*pad = p;
}
if !low.is_null() {
*low = l;
}
if !high.is_null() {
*high = h;
}
}
PunktfunkStatus::Ok
}
Err(e) => e.status(),
}
})
}
/// Send one input event to the host as a QUIC datagram (non-blocking enqueue).
///
/// # Safety
/// `c` is a valid connection handle; `ev` points to a valid [`InputEvent`].
#[cfg(feature = "quic")]
#[no_mangle]
pub unsafe extern "C" fn punktfunk_connection_send_input(
c: *mut PunktfunkConnection,
ev: *const InputEvent,
) -> PunktfunkStatus {
guard(|| {
let c = match unsafe { c.as_ref() } {
Some(c) => c,
None => return PunktfunkStatus::NullPointer,
};
let ev = match unsafe { ev.as_ref() } {
Some(e) => e,
None => return PunktfunkStatus::NullPointer,
};
match c.inner.send_input(ev) {
Ok(()) => PunktfunkStatus::Ok,
Err(e) => e.status(),
}
})
}
/// The host-confirmed session mode (from the Welcome). Safe any time after connect.
///
/// # Safety
/// `c` is a valid connection handle; out pointers are writable (NULLs are skipped).
#[cfg(feature = "quic")]
#[no_mangle]
pub unsafe extern "C" fn punktfunk_connection_mode(
c: *const PunktfunkConnection,
width: *mut u32,
height: *mut u32,
refresh_hz: *mut u32,
) -> PunktfunkStatus {
guard(|| {
let c = match unsafe { c.as_ref() } {
Some(c) => c,
None => return PunktfunkStatus::NullPointer,
};
unsafe {
if !width.is_null() {
*width = c.inner.mode.width;
}
if !height.is_null() {
*height = c.inner.mode.height;
}
if !refresh_hz.is_null() {
*refresh_hz = c.inner.mode.refresh_hz;
}
}
PunktfunkStatus::Ok
})
}
/// Close the connection and free the handle (joins the internal threads). NULL is a no-op.
///
/// # Safety
/// `c` was returned by [`punktfunk_connect`] and is not used after this call.
#[cfg(feature = "quic")]
#[no_mangle]
pub unsafe extern "C" fn punktfunk_connection_close(c: *mut PunktfunkConnection) {
if !c.is_null() {
drop(unsafe { Box::from_raw(c) });
}
}
+341
View File
@@ -0,0 +1,341 @@
//! The embeddable `punktfunk/1` client connector (M4 groundwork), behind the `quic` feature.
//!
//! [`NativeClient::connect`] runs the full client side of the protocol — QUIC handshake
//! ([`crate::quic`]), UDP data plane ([`crate::session::Session`] on a native thread), input
//! datagrams — and hands the embedder a dead-simple surface: *pull reassembled access units,
//! push input events*. This is what the platform clients (SwiftUI/VideoToolbox, Android, …)
//! link via the C ABI (`punktfunk_connect` & co. in [`crate::abi`]); `punktfunk-client-rs` is the
//! Rust-native consumer of the same flow.
//!
//! Threading: one worker thread owns a tokio runtime (QUIC control plane only — design
//! invariant) plus a blocking data-plane pump; frames cross to the embedder over a bounded
//! channel. All methods are safe to call from any single embedder thread.
use crate::config::{Mode, Role};
use crate::error::{PunktfunkError, Result};
use crate::input::InputEvent;
use crate::quic::{endpoint, io, Hello, Start, Welcome};
use crate::session::{Frame, Session};
use crate::transport::UdpTransport;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{Receiver, RecvTimeoutError, SyncSender};
use std::sync::Arc;
use std::time::Duration;
/// Frames buffered between the data-plane pump and the embedder. Small: the embedder
/// (decoder) should drain at frame rate; when it falls behind, the newest frame is dropped
/// (display freshness over completeness — FEC/keyframes recover).
const FRAME_QUEUE: usize = 16;
/// Audio packets buffered for the embedder: 64 × 5 ms = 320 ms of slack. A lagging
/// embedder drops the newest packet (the audio renderer conceals the gap).
const AUDIO_QUEUE: usize = 64;
/// Rumble updates buffered for the embedder. Overflow drops the NEWEST update (same
/// `try_send` discipline as the other planes) — the host re-sends rumble state
/// periodically, so a dropped transition (including a stop) heals within ~500 ms.
const RUMBLE_QUEUE: usize = 16;
/// One Opus packet from the host's audio datagram stream (48 kHz stereo, 5 ms frames).
#[derive(Clone, Debug)]
pub struct AudioPacket {
pub seq: u32,
pub pts_ns: u64,
/// The raw Opus payload — feed it to an Opus decoder as one frame.
pub data: Vec<u8>,
}
pub struct NativeClient {
frames: Receiver<Frame>,
audio: Receiver<AudioPacket>,
rumble: Receiver<(u16, u16, u16)>,
input_tx: tokio::sync::mpsc::UnboundedSender<InputEvent>,
shutdown: Arc<AtomicBool>,
worker: Option<std::thread::JoinHandle<()>>,
/// The host-confirmed session mode (from the Welcome).
pub mode: Mode,
/// SHA-256 fingerprint of the certificate the host actually presented. A TOFU caller
/// (`pin = None`) persists this and passes it as the pin from then on.
pub host_fingerprint: [u8; 32],
}
impl NativeClient {
/// Connect to a `punktfunk/1` host and start the session at (up to) `mode`. Blocks until the
/// handshake completes or `timeout` elapses.
///
/// `pin`: expected SHA-256 of the host's certificate. `Some` and the host presents
/// anything else → the handshake is rejected ([`PunktfunkError::Crypto`]). `None` = trust on
/// first use; check [`NativeClient::host_fingerprint`] after connecting.
pub fn connect(
host: &str,
port: u16,
mode: Mode,
pin: Option<[u8; 32]>,
timeout: Duration,
) -> Result<NativeClient> {
let (frame_tx, frame_rx) = std::sync::mpsc::sync_channel::<Frame>(FRAME_QUEUE);
let (audio_tx, audio_rx) = std::sync::mpsc::sync_channel::<AudioPacket>(AUDIO_QUEUE);
let (rumble_tx, rumble_rx) = std::sync::mpsc::sync_channel::<(u16, u16, u16)>(RUMBLE_QUEUE);
let (input_tx, input_rx) = tokio::sync::mpsc::unbounded_channel::<InputEvent>();
let (ready_tx, ready_rx) = std::sync::mpsc::channel::<Result<(Mode, [u8; 32])>>();
let shutdown = Arc::new(AtomicBool::new(false));
let host = host.to_string();
let shutdown_w = shutdown.clone();
let worker = std::thread::Builder::new()
.name("punktfunk-client".into())
.spawn(move || {
let rt = match tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()
{
Ok(rt) => rt,
Err(e) => {
let _ = ready_tx.send(Err(PunktfunkError::Io(e)));
return;
}
};
rt.block_on(worker_main(WorkerArgs {
host,
port,
mode,
pin,
frame_tx,
audio_tx,
rumble_tx,
input_rx,
ready_tx,
shutdown: shutdown_w,
}));
})
.map_err(PunktfunkError::Io)?;
let (negotiated, fingerprint) = match ready_rx.recv_timeout(timeout) {
Ok(Ok(t)) => t,
Ok(Err(e)) => return Err(e),
Err(_) => {
shutdown.store(true, Ordering::SeqCst);
return Err(PunktfunkError::Timeout);
}
};
Ok(NativeClient {
frames: frame_rx,
audio: audio_rx,
rumble: rumble_rx,
input_tx,
shutdown,
worker: Some(worker),
mode: negotiated,
host_fingerprint: fingerprint,
})
}
/// Pull the next reassembled, FEC-recovered access unit; [`PunktfunkError::NoFrame`] on
/// timeout, [`PunktfunkError::Closed`]-class errors once the session ended.
///
/// Plane concurrency: each pull method drains its own queue, so video, audio and
/// rumble may each be pulled from their own thread — but at most one thread per plane
/// (`&self` here supports the cross-plane sharing; a plane's queue is still
/// single-consumer by contract).
pub fn next_frame(&self, timeout: Duration) -> Result<Frame> {
match self.frames.recv_timeout(timeout) {
Ok(f) => Ok(f),
Err(RecvTimeoutError::Timeout) => Err(PunktfunkError::NoFrame),
Err(RecvTimeoutError::Disconnected) => Err(PunktfunkError::Closed),
}
}
/// Pull the next Opus audio packet; [`PunktfunkError::NoFrame`] on timeout,
/// [`PunktfunkError::Closed`] once the session ended. Drain on a dedicated audio thread —
/// packets arrive every 5 ms.
pub fn next_audio(&self, timeout: Duration) -> Result<AudioPacket> {
match self.audio.recv_timeout(timeout) {
Ok(p) => Ok(p),
Err(RecvTimeoutError::Timeout) => Err(PunktfunkError::NoFrame),
Err(RecvTimeoutError::Disconnected) => Err(PunktfunkError::Closed),
}
}
/// Pull the next rumble update `(pad, low, high)`; same semantics as
/// [`NativeClient::next_audio`]. Amplitudes are 0..0xFFFF, `(0, 0)` = stop.
pub fn next_rumble(&self, timeout: Duration) -> Result<(u16, u16, u16)> {
match self.rumble.recv_timeout(timeout) {
Ok(r) => Ok(r),
Err(RecvTimeoutError::Timeout) => Err(PunktfunkError::NoFrame),
Err(RecvTimeoutError::Disconnected) => Err(PunktfunkError::Closed),
}
}
/// Queue one input event for delivery as a QUIC datagram.
pub fn send_input(&self, ev: &InputEvent) -> Result<()> {
self.input_tx.send(*ev).map_err(|_| PunktfunkError::Closed)
}
}
impl Drop for NativeClient {
fn drop(&mut self) {
self.shutdown.store(true, Ordering::SeqCst);
if let Some(w) = self.worker.take() {
let _ = w.join();
}
}
}
struct WorkerArgs {
host: String,
port: u16,
mode: Mode,
pin: Option<[u8; 32]>,
frame_tx: SyncSender<Frame>,
audio_tx: SyncSender<AudioPacket>,
rumble_tx: SyncSender<(u16, u16, u16)>,
input_rx: tokio::sync::mpsc::UnboundedReceiver<InputEvent>,
ready_tx: std::sync::mpsc::Sender<Result<(Mode, [u8; 32])>>,
shutdown: Arc<AtomicBool>,
}
/// The worker: QUIC handshake, then the input/datagram tasks + the blocking data-plane pump.
async fn worker_main(args: WorkerArgs) {
let WorkerArgs {
host,
port,
mode,
pin,
frame_tx,
audio_tx,
rumble_tx,
mut input_rx,
ready_tx,
shutdown,
} = args;
let setup = async {
let remote: std::net::SocketAddr = format!("{host}:{port}")
.parse()
.map_err(|_| PunktfunkError::InvalidArg("host:port"))?;
let (ep, observed) = endpoint::client_pinned(pin);
let ep = ep.map_err(|e| PunktfunkError::Io(std::io::Error::other(e.to_string())))?;
let conn = ep
.connect(remote, "punktfunk")
.map_err(|_| PunktfunkError::InvalidArg("connect"))?
.await
.map_err(|e| {
// A pin mismatch surfaces as a TLS failure; report it as a crypto error so
// the embedder can distinguish "wrong host identity" from plain IO trouble.
let fp_mismatch = pin.is_some()
&& observed.lock().unwrap().map(|fp| Some(fp) != pin) == Some(true);
if fp_mismatch {
PunktfunkError::Crypto
} else {
PunktfunkError::Io(std::io::Error::other(e.to_string()))
}
})?;
let fingerprint = observed.lock().unwrap().unwrap_or([0u8; 32]);
let (mut send, mut recv) = conn
.open_bi()
.await
.map_err(|e| PunktfunkError::Io(std::io::Error::other(e.to_string())))?;
io::write_msg(
&mut send,
&Hello {
abi_version: crate::ABI_VERSION,
mode,
}
.encode(),
)
.await?;
let welcome = Welcome::decode(&io::read_msg(&mut recv).await?)?;
// Reserve our data-plane port, then start the host.
let probe = std::net::UdpSocket::bind("0.0.0.0:0")?;
let udp_port = probe.local_addr()?.port();
drop(probe);
io::write_msg(
&mut send,
&Start {
client_udp_port: udp_port,
}
.encode(),
)
.await?;
let host_udp = std::net::SocketAddr::new(remote.ip(), welcome.udp_port);
let transport =
UdpTransport::connect(&format!("0.0.0.0:{udp_port}"), &host_udp.to_string())?;
let session = Session::new(welcome.session_config(Role::Client), Box::new(transport))?;
Ok::<_, PunktfunkError>((conn, session, welcome.mode, fingerprint))
};
let (conn, mut session, negotiated, fingerprint) = match setup.await {
Ok(t) => t,
Err(e) => {
let _ = ready_tx.send(Err(e));
return;
}
};
let _ = ready_tx.send(Ok((negotiated, fingerprint)));
// Input task: embedder events → QUIC datagrams.
let input_conn = conn.clone();
tokio::spawn(async move {
while let Some(ev) = input_rx.recv().await {
let _ = input_conn.send_datagram(ev.encode().to_vec().into());
}
});
// Datagram demux: host → client audio/rumble (try_send: a lagging embedder drops the
// newest packet rather than backing up the QUIC receive path).
let dgram_conn = conn.clone();
tokio::spawn(async move {
while let Ok(d) = dgram_conn.read_datagram().await {
match d.first() {
Some(&crate::quic::AUDIO_MAGIC) => {
if let Some((seq, pts_ns, opus)) = crate::quic::decode_audio_datagram(&d) {
let _ = audio_tx.try_send(AudioPacket {
seq,
pts_ns,
data: opus.to_vec(),
});
}
}
Some(&crate::quic::RUMBLE_MAGIC) => {
if let Some(r) = crate::quic::decode_rumble_datagram(&d) {
let _ = rumble_tx.try_send(r);
}
}
_ => {} // unknown tag — a newer host; ignore
}
}
});
// Watch for connection close → stop the pump.
{
let shutdown = shutdown.clone();
let conn = conn.clone();
tokio::spawn(async move {
conn.closed().await;
shutdown.store(true, Ordering::SeqCst);
});
}
// Data-plane pump on a blocking thread: poll the session, hand frames to the embedder.
// try_send drops the newest frame when the embedder lags (freshness over completeness).
let pump_shutdown = shutdown.clone();
let _ = tokio::task::spawn_blocking(move || {
while !pump_shutdown.load(Ordering::SeqCst) {
match session.poll_frame() {
Ok(frame) => {
let _ = frame_tx.try_send(frame);
}
Err(PunktfunkError::NoFrame) => {
std::thread::sleep(Duration::from_micros(300));
}
Err(_) => break,
}
}
})
.await;
conn.close(0u32.into(), b"client closed");
}
+233
View File
@@ -0,0 +1,233 @@
//! Session configuration and protocol/FEC parameters.
use crate::error::{PunktfunkError, Result};
use crate::packet::{CRYPTO_OVERHEAD, HEADER_LEN, MAX_DATAGRAM_BYTES};
use zeroize::Zeroize;
/// Which side of the stream this session drives.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Role {
Host = 0,
Client = 1,
}
/// Negotiated protocol generation. P1 is GameStream-compatible (GF(2⁸)); P2 is the
/// `punktfunk/1` extension (GF(2¹⁶), multi-block framing, optional QUIC control).
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProtocolPhase {
P1GameStream = 1,
P2Punktfunk = 2,
}
/// Erasure-coding field. Mirrors the on-wire `fec_scheme` tag.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FecScheme {
/// GF(2⁸) classic RS — Moonlight/GameStream compatible, ≤ 255 shards/block.
Gf8 = 0,
/// GF(2¹⁶) Leopard-RS — SIMD, O(n log n), up to 65535 shards/block.
Gf16 = 1,
}
impl FecScheme {
pub fn from_u8(v: u8) -> Option<FecScheme> {
match v {
0 => Some(FecScheme::Gf8),
1 => Some(FecScheme::Gf16),
_ => None,
}
}
/// Hard per-block total-shard ceiling for the field (data + recovery).
pub fn max_total_shards(self) -> usize {
match self {
FecScheme::Gf8 => 255,
FecScheme::Gf16 => u16::MAX as usize, // wire fields are u16
}
}
}
/// A client-sized display mode the host should produce on the virtual output.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Mode {
pub width: u32,
pub height: u32,
pub refresh_hz: u32,
}
/// Per-block FEC parameters. Recovery count is derived from `fec_percent` exactly as
/// GameStream does: `m = ceil(k * fec_percent / 100)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FecConfig {
pub scheme: FecScheme,
/// Recovery overhead as a percentage of data shards (0 disables FEC).
pub fec_percent: u8,
/// Maximum data shards per FEC block; larger frames split into multiple blocks.
/// GF(2⁸) is bounded at 255 total shards, so keep this ≤ ~200 for `Gf8`.
pub max_data_per_block: u16,
}
impl FecConfig {
/// Recovery (parity) shard count for a block of `data_shards` shards.
pub fn recovery_for(&self, data_shards: usize) -> usize {
if self.fec_percent == 0 || data_shards == 0 {
return 0;
}
// ceil(k * pct / 100)
(data_shards * self.fec_percent as usize).div_ceil(100)
}
}
/// Largest shard payload that still fits a datagram once header + crypto overhead are
/// added. Bounds `shard_payload` so packets never exceed [`MAX_DATAGRAM_BYTES`].
pub const fn max_shard_payload() -> usize {
MAX_DATAGRAM_BYTES - HEADER_LEN - CRYPTO_OVERHEAD
}
/// Everything needed to construct a [`Session`](crate::session::Session).
///
/// `Debug` is implemented by hand to redact `key`/`salt`, and `key`/`salt` are zeroized
/// on drop, so secrets neither leak into logs nor linger in freed memory.
#[derive(Clone)]
pub struct Config {
pub role: Role,
pub phase: ProtocolPhase,
pub fec: FecConfig,
/// Shard payload bytes per packet. Must be even and ≤ [`max_shard_payload`].
pub shard_payload: usize,
/// Largest encoded access unit the reassembler will accept (bounds memory against
/// hostile/corrupt headers; see [`Session`](crate::session::Session)).
pub max_frame_bytes: usize,
pub encrypt: bool,
/// AES-128 session key established during pairing. MUST be unique per session when
/// `encrypt` is set (see the nonce-uniqueness contract in [`crate::crypto`]).
pub key: [u8; 16],
/// Per-session nonce salt, established alongside `key` during pairing. MUST be
/// unique per (key, session).
pub salt: [u8; 4],
/// Test hook: when non-zero, the loopback transport deterministically drops one of
/// every `loopback_drop_period` packets it sends. 0 = lossless.
pub loopback_drop_period: u32,
}
impl Drop for Config {
fn drop(&mut self) {
self.key.zeroize();
self.salt.zeroize();
}
}
impl std::fmt::Debug for Config {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Config")
.field("role", &self.role)
.field("phase", &self.phase)
.field("fec", &self.fec)
.field("shard_payload", &self.shard_payload)
.field("max_frame_bytes", &self.max_frame_bytes)
.field("encrypt", &self.encrypt)
.field("key", &"<redacted>")
.field("salt", &"<redacted>")
.field("loopback_drop_period", &self.loopback_drop_period)
.finish()
}
}
impl Config {
/// Validate every invariant the hot path and the reassembler rely on. Rejecting here
/// is what keeps the receive-side parser's allocations bounded.
pub fn validate(&self) -> Result<()> {
if self.shard_payload == 0 || self.shard_payload % 2 != 0 {
return Err(PunktfunkError::InvalidArg(
"shard_payload must be even and > 0",
));
}
if self.shard_payload > max_shard_payload() {
return Err(PunktfunkError::InvalidArg(
"shard_payload too large to fit a datagram (header + crypto overhead)",
));
}
if self.fec.max_data_per_block == 0 {
return Err(PunktfunkError::InvalidArg("max_data_per_block must be > 0"));
}
// The per-block total (data + recovery) must fit both the field ceiling and the
// u16 wire fields.
let k = self.fec.max_data_per_block as usize;
let total = k + self.fec.recovery_for(k);
if total > self.fec.scheme.max_total_shards() {
return Err(PunktfunkError::InvalidArg(
"max_data_per_block + recovery exceeds the FEC scheme's shard ceiling",
));
}
if self.max_frame_bytes == 0 {
return Err(PunktfunkError::InvalidArg("max_frame_bytes must be > 0"));
}
// The frame must not need more FEC blocks than the u16 block-count field allows.
let total_data = self.max_frame_bytes.div_ceil(self.shard_payload).max(1);
let max_blocks = total_data.div_ceil(k).max(1);
if max_blocks > u16::MAX as usize {
return Err(PunktfunkError::InvalidArg(
"max_frame_bytes too large for this shard/block configuration (block count overflows u16)",
));
}
if self.encrypt && self.key == [0u8; 16] {
return Err(PunktfunkError::InvalidArg(
"encrypt requires a non-zero session key (see crypto nonce-uniqueness contract)",
));
}
Ok(())
}
/// Sensible P1 defaults: GF(2⁸), 15% FEC, ~1 KiB shards, no encryption, 64 MiB frame
/// cap. When enabling encryption, replace `key`/`salt` with per-session values from
/// pairing — the all-zero defaults are rejected by [`validate`](Self::validate).
pub fn p1_defaults(role: Role) -> Self {
Config {
role,
phase: ProtocolPhase::P1GameStream,
fec: FecConfig {
scheme: FecScheme::Gf8,
fec_percent: 15,
max_data_per_block: 200,
},
shard_payload: 1024,
max_frame_bytes: 64 * 1024 * 1024,
encrypt: false,
key: [0u8; 16],
salt: [0u8; 4],
loopback_drop_period: 0,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rejects_encrypt_with_zero_key() {
let mut c = Config::p1_defaults(Role::Host);
c.encrypt = true; // key is still all-zero
assert!(c.validate().is_err());
c.key = [1u8; 16];
assert!(c.validate().is_ok());
}
#[test]
fn rejects_oversized_shard_payload() {
let mut c = Config::p1_defaults(Role::Host);
c.shard_payload = max_shard_payload() + 2; // still even, but won't fit a datagram
assert!(c.validate().is_err());
}
#[test]
fn rejects_block_exceeding_scheme_ceiling() {
let mut c = Config::p1_defaults(Role::Host); // Gf8, ceiling 255
c.fec.max_data_per_block = 250;
c.fec.fec_percent = 15; // 250 + ceil(250*15/100)=288 > 255
assert!(c.validate().is_err());
}
}
+149
View File
@@ -0,0 +1,149 @@
//! AES-128-GCM session sealing, matching GameStream's video crypto in P1.
//!
//! ## Nonce uniqueness (the GCM safety requirement)
//!
//! The 96-bit nonce is `salt (4 bytes) || sequence (8 bytes, big-endian)`. Reusing a
//! `(key, nonce)` pair under AES-GCM is catastrophic, so two precautions apply:
//!
//! 1. **Per-direction salts.** Host and client share one `key` and `salt`, and each
//! counts its sequence from 0. To stop the host's video stream and the client's input
//! stream from colliding on `(key, nonce)`, the top bit of `salt[0]` is set to the
//! sender's direction — so the two directions occupy disjoint nonce spaces.
//! 2. **Per-session key+salt.** The pairing layer MUST hand each session a fresh
//! `(key, salt)`; reusing them across sessions reintroduces nonce reuse. `Config`'s
//! all-zero key with `encrypt = true` is rejected by `Config::validate` to catch the
//! obvious footgun.
//!
//! The sequence number is also passed as AEAD associated data, so tampering with the
//! on-wire sequence is detected (the tag check fails) rather than silently shifting the
//! nonce. Note: this layer does not provide anti-replay — see `Session`.
use crate::config::Role;
use crate::error::{PunktfunkError, Result};
use aes_gcm::aead::{Aead, KeyInit, Payload};
use aes_gcm::{Aes128Gcm, Key, Nonce};
/// 16-byte AEAD authentication tag appended by GCM.
pub const TAG_LEN: usize = 16;
pub struct SessionCrypto {
cipher: Aes128Gcm,
/// Salt for nonces we seal with (our direction).
send_salt: [u8; 4],
/// Salt for nonces we open with (the peer's direction).
recv_salt: [u8; 4],
}
impl SessionCrypto {
pub fn new(key: &[u8; 16], salt: [u8; 4], role: Role) -> Self {
let key = Key::<Aes128Gcm>::from_slice(key);
let own = direction(role);
SessionCrypto {
cipher: Aes128Gcm::new(key),
send_salt: dir_salt(salt, own),
recv_salt: dir_salt(salt, own ^ 1),
}
}
/// Seal `plaintext` for sequence `seq`, returning `ciphertext || tag`. `seq` is
/// authenticated as associated data.
pub fn seal(&self, seq: u64, plaintext: &[u8]) -> Result<Vec<u8>> {
let nonce = nonce(self.send_salt, seq);
self.cipher
.encrypt(
Nonce::from_slice(&nonce),
Payload {
msg: plaintext,
aad: &seq.to_be_bytes(),
},
)
.map_err(|_| PunktfunkError::Crypto)
}
/// Open `ciphertext || tag` for sequence `seq` (also bound as associated data).
pub fn open(&self, seq: u64, ciphertext: &[u8]) -> Result<Vec<u8>> {
let nonce = nonce(self.recv_salt, seq);
self.cipher
.decrypt(
Nonce::from_slice(&nonce),
Payload {
msg: ciphertext,
aad: &seq.to_be_bytes(),
},
)
.map_err(|_| PunktfunkError::Crypto)
}
}
fn direction(role: Role) -> u8 {
match role {
Role::Host => 0,
Role::Client => 1,
}
}
/// Fold a 1-bit direction into the salt (top bit of `salt[0]`) so the two directions of
/// a session never share a nonce under the same key.
fn dir_salt(mut salt: [u8; 4], dir: u8) -> [u8; 4] {
salt[0] = (salt[0] & 0x7f) | (dir << 7);
salt
}
fn nonce(salt: [u8; 4], seq: u64) -> [u8; 12] {
let mut n = [0u8; 12];
n[..4].copy_from_slice(&salt);
n[4..].copy_from_slice(&seq.to_be_bytes());
n
}
/// Generate a fresh random AES-128 session key (control-plane / pairing use).
pub fn random_key() -> [u8; 16] {
let mut k = [0u8; 16];
rand::RngCore::fill_bytes(&mut rand::rng(), &mut k);
k
}
/// Generate a fresh random per-session nonce salt.
pub fn random_salt() -> [u8; 4] {
let mut s = [0u8; 4];
rand::RngCore::fill_bytes(&mut rand::rng(), &mut s);
s
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn seal_open_roundtrip_cross_direction() {
let key = random_key();
let salt = random_salt();
let host = SessionCrypto::new(&key, salt, Role::Host);
let client = SessionCrypto::new(&key, salt, Role::Client);
let msg = b"the quick brown fox";
let sealed = host.seal(42, msg).unwrap(); // host -> client (video direction)
assert_ne!(&sealed[..msg.len()], &msg[..]); // actually encrypted
assert_eq!(sealed.len(), msg.len() + TAG_LEN);
assert_eq!(client.open(42, &sealed).unwrap(), msg);
// Wrong sequence (nonce + AAD) → authentication failure.
assert!(client.open(43, &sealed).is_err());
// Direction separation: the host opens with the peer (client) salt, so it cannot
// open its own outbound packet → distinct nonce spaces per direction.
assert!(host.open(42, &sealed).is_err());
}
#[test]
fn directions_use_distinct_nonce_spaces() {
let key = random_key();
let salt = [0u8; 4]; // even an all-zero base salt must separate the directions
let host = SessionCrypto::new(&key, salt, Role::Host);
let client = SessionCrypto::new(&key, salt, Role::Client);
// Same seq, same key, opposite directions → different ciphertext (no reuse).
assert_ne!(
host.seal(0, b"abc").unwrap(),
client.seal(0, b"abc").unwrap()
);
}
}
+64
View File
@@ -0,0 +1,64 @@
//! Error type and the stable C ABI status codes it maps to.
use thiserror::Error;
/// The core's internal error type. Crosses the C ABI as a [`PunktfunkStatus`] code.
#[derive(Debug, Error)]
pub enum PunktfunkError {
#[error("invalid argument: {0}")]
InvalidArg(&'static str),
#[error("fec error: {0}")]
Fec(#[from] crate::fec::FecError),
#[error("crypto seal/open failed")]
Crypto,
#[error("malformed packet")]
BadPacket,
#[error("no complete frame available yet")]
NoFrame,
#[error("unsupported: {0}")]
Unsupported(&'static str),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("timed out")]
Timeout,
#[error("session closed")]
Closed,
}
pub type Result<T> = core::result::Result<T, PunktfunkError>;
/// Stable C ABI status codes. `Ok` is 0; all errors are negative so callers can
/// test `rc < 0`. Do not renumber existing variants — only append.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PunktfunkStatus {
Ok = 0,
InvalidArg = -1,
Fec = -2,
Crypto = -3,
BadPacket = -4,
NoFrame = -5,
Unsupported = -6,
Io = -7,
NullPointer = -8,
Timeout = -9,
Closed = -10,
Panic = -99,
}
impl PunktfunkError {
/// Map to the C ABI status code.
pub fn status(&self) -> PunktfunkStatus {
match self {
PunktfunkError::InvalidArg(_) => PunktfunkStatus::InvalidArg,
PunktfunkError::Fec(_) => PunktfunkStatus::Fec,
PunktfunkError::Crypto => PunktfunkStatus::Crypto,
PunktfunkError::BadPacket => PunktfunkStatus::BadPacket,
PunktfunkError::NoFrame => PunktfunkStatus::NoFrame,
PunktfunkError::Unsupported(_) => PunktfunkStatus::Unsupported,
PunktfunkError::Io(_) => PunktfunkStatus::Io,
PunktfunkError::Timeout => PunktfunkStatus::Timeout,
PunktfunkError::Closed => PunktfunkStatus::Closed,
}
}
}
+84
View File
@@ -0,0 +1,84 @@
//! GF(2¹⁶) Leopard-RS backend (`reed-solomon-simd`). SIMD, O(n log n), up to 65535
//! shards/block — this is what removes the GameStream 255-shard / ~1 Gbps wall.
//! Shard length must be even.
use super::{validate_block_shape, validate_encode_shape, ErasureCoder, FecError};
use crate::config::FecScheme;
pub struct Gf16Coder;
impl ErasureCoder for Gf16Coder {
fn scheme(&self) -> FecScheme {
FecScheme::Gf16
}
fn encode(&self, data: &[Vec<u8>], recovery_count: usize) -> Result<Vec<Vec<u8>>, FecError> {
if recovery_count == 0 {
return Ok(Vec::new());
}
validate_encode_shape(data)?;
let k = data.len();
if data[0].len() % 2 != 0 {
return Err(FecError::Config("GF(2^16) shard length must be even"));
}
reed_solomon_simd::encode(k, recovery_count, data)
.map_err(|_| FecError::Backend("gf16 encode"))
}
fn reconstruct(
&self,
data_count: usize,
recovery_count: usize,
received: &mut [Option<Vec<u8>>],
) -> Result<Vec<Vec<u8>>, FecError> {
validate_block_shape(received, data_count, recovery_count)?;
let present = received.iter().filter(|s| s.is_some()).count();
if present < data_count {
return Err(FecError::TooFewShards {
have: present,
need: data_count,
});
}
// Fast path: all originals already present, or FEC disabled.
let originals_complete = received[..data_count].iter().all(|s| s.is_some());
if recovery_count == 0 || originals_complete {
let mut out = Vec::with_capacity(data_count);
for slot in received.iter().take(data_count) {
out.push(slot.clone().ok_or(FecError::TooFewShards {
have: present,
need: data_count,
})?);
}
return Ok(out);
}
// Hand the decoder the surviving originals and recovery shards, indexed.
let original_in: Vec<(usize, &[u8])> = received[..data_count]
.iter()
.enumerate()
.filter_map(|(i, s)| s.as_deref().map(|b| (i, b)))
.collect();
let recovery_in: Vec<(usize, &[u8])> = received[data_count..data_count + recovery_count]
.iter()
.enumerate()
.filter_map(|(j, s)| s.as_deref().map(|b| (j, b)))
.collect();
let restored =
reed_solomon_simd::decode(data_count, recovery_count, original_in, recovery_in)
.map_err(|_| FecError::Backend("gf16 decode"))?;
// Merge surviving originals with the recovered ones.
let mut out: Vec<Vec<u8>> = Vec::with_capacity(data_count);
for (i, slot) in received[..data_count].iter().enumerate() {
if let Some(s) = slot {
out.push(s.clone());
} else if let Some(s) = restored.get(&i) {
out.push(s.clone());
} else {
return Err(FecError::Backend("gf16 decode left an original missing"));
}
}
Ok(out)
}
}
+140
View File
@@ -0,0 +1,140 @@
//! GF(2⁸) classic ReedSolomon backend (vendored `fec-rs`). Uses the **Cauchy** generator
//! matrix `M[j][i] = inv[(m+i)^j]` over GF(2⁸) (poly 0x1d) — byte-identical to the `nanors`
//! library Moonlight uses, so the parity this produces is recoverable by a stock Moonlight
//! client (unlike Vandermonde RS, whose parity is not interoperable). Hard ceiling: data +
//! recovery ≤ 255 shards/block.
use super::{validate_block_shape, validate_encode_shape, ErasureCoder, FecError};
use crate::config::FecScheme;
use fec_rs::ReedSolomon;
pub struct Gf8Coder;
impl ErasureCoder for Gf8Coder {
fn scheme(&self) -> FecScheme {
FecScheme::Gf8
}
fn encode(&self, data: &[Vec<u8>], recovery_count: usize) -> Result<Vec<Vec<u8>>, FecError> {
if recovery_count == 0 {
return Ok(Vec::new());
}
validate_encode_shape(data)?;
let k = data.len();
let shard_len = data[0].len();
let rs = ReedSolomon::new(k, recovery_count)
.map_err(|_| FecError::Config("invalid GF(2^8) shard counts"))?;
// fec-rs fills parity in place: shards = data || zeroed parity.
let mut shards: Vec<Vec<u8>> = Vec::with_capacity(k + recovery_count);
shards.extend_from_slice(data);
shards.resize_with(k + recovery_count, || vec![0u8; shard_len]);
rs.encode(&mut shards)
.map_err(|_| FecError::Backend("gf8 encode"))?;
Ok(shards.split_off(k))
}
fn reconstruct(
&self,
data_count: usize,
recovery_count: usize,
received: &mut [Option<Vec<u8>>],
) -> Result<Vec<Vec<u8>>, FecError> {
validate_block_shape(received, data_count, recovery_count)?;
let present = received.iter().filter(|s| s.is_some()).count();
if present < data_count {
return Err(FecError::TooFewShards {
have: present,
need: data_count,
});
}
if recovery_count == 0 {
// No FEC: every original must already be present.
return collect_originals(received, data_count);
}
let rs = ReedSolomon::new(data_count, recovery_count)
.map_err(|_| FecError::Config("invalid GF(2^8) shard counts"))?;
rs.reconstruct_data(received)
.map_err(|_| FecError::Backend("gf8 reconstruct"))?;
collect_originals(received, data_count)
}
}
fn collect_originals(
received: &[Option<Vec<u8>>],
data_count: usize,
) -> Result<Vec<Vec<u8>>, FecError> {
let mut out = Vec::with_capacity(data_count);
for slot in received.iter().take(data_count) {
out.push(
slot.clone()
.ok_or(FecError::Backend("reconstruction left an original missing"))?,
);
}
Ok(out)
}
#[cfg(test)]
mod tests {
use super::*;
/// Locks byte-exact compatibility with Moonlight's `nanors` (Cauchy matrix
/// `M[j][i] = inv[(m+i)^j]`, GF(2⁸) poly 0x1d). If the backend ever switched matrices,
/// these vectors would break and our parity would no longer be Moonlight-decodable.
#[test]
fn nanors_exact_parity_vectors() {
let coder = Gf8Coder;
// The definitive nanors vector (k=4, m=2): single-byte shards [10,20,30,40] → [136, 0].
let data = vec![vec![10u8], vec![20], vec![30], vec![40]];
let parity = coder.encode(&data, 2).unwrap();
assert_eq!(parity, vec![vec![136u8], vec![0u8]]);
// Cross-check independently from the Cauchy parity rows (proves the matrix, not just a
// memorized output): parity[j] = XOR_i M[j][i] · data[i] over GF(2⁸).
let rows = [[142u8, 244, 71, 167], [244, 142, 167, 71]];
let din = [10u8, 20, 30, 40];
for (j, row) in rows.iter().enumerate() {
let expect = row
.iter()
.zip(din)
.fold(0u8, |acc, (&m, d)| acc ^ gf_mul(m, d));
assert_eq!(parity[j][0], expect, "parity row {j}");
}
}
/// Round-trip: erase `m` data shards and confirm reconstruction recovers the originals.
#[test]
fn recovers_erased_data_shards() {
let coder = Gf8Coder;
let data: Vec<Vec<u8>> = (0..6).map(|i| vec![i as u8; 8]).collect();
let parity = coder.encode(&data, 3).unwrap();
let mut received: Vec<Option<Vec<u8>>> = data
.iter()
.cloned()
.map(Some)
.chain(parity.into_iter().map(Some))
.collect();
// Erase 3 data shards (the FEC budget) + nothing else.
received[1] = None;
received[3] = None;
received[5] = None;
let recovered = coder.reconstruct(6, 3, &mut received).unwrap();
assert_eq!(recovered, data);
}
/// GF(2⁸) multiply, reduction poly 0x1d — independent of the backend.
fn gf_mul(mut a: u8, mut b: u8) -> u8 {
let mut p = 0u8;
for _ in 0..8 {
if b & 1 != 0 {
p ^= a;
}
let hi = a & 0x80;
a <<= 1;
if hi != 0 {
a ^= 0x1d;
}
b >>= 1;
}
p
}
}
+167
View File
@@ -0,0 +1,167 @@
//! Erasure coding. Two backends behind one [`ErasureCoder`] trait: GF(2⁸) (classic
//! ReedSolomon, Moonlight-compatible, P1) and GF(2¹⁶) Leopard-RS (the wall-breaker, P2).
//!
//! The wall this breaks: GameStream's GF(2⁸) RS caps a block at 255 shards, which at
//! 5120×1440@240 is hit around 1 Gbps. GF(2¹⁶) raises that ceiling to 65535 shards and
//! runs in O(n log n) with SIMD, so the per-frame shard count stops being the limiter.
mod gf16;
mod gf8;
pub use gf16::Gf16Coder;
pub use gf8::Gf8Coder;
use crate::config::FecScheme;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum FecError {
#[error("invalid shard configuration: {0}")]
Config(&'static str),
#[error("too few shards to reconstruct (have {have}, need {need})")]
TooFewShards { have: usize, need: usize },
#[error("backend error: {0}")]
Backend(&'static str),
}
/// Backend-agnostic erasure coder. All shards in a block are equal length.
pub trait ErasureCoder: Send + Sync {
fn scheme(&self) -> FecScheme;
/// Encode `data` (K original shards) into `recovery_count` (M) parity shards.
/// Returns the M recovery shards. `recovery_count == 0` returns an empty `Vec`.
fn encode(&self, data: &[Vec<u8>], recovery_count: usize) -> Result<Vec<Vec<u8>>, FecError>;
/// Reconstruct the K original shards. `received` has length K+M: indices `0..K` are
/// originals, `K..K+M` are recovery shards; `Some` = present, `None` = lost.
/// Returns the K original shards in order.
fn reconstruct(
&self,
data_count: usize,
recovery_count: usize,
received: &mut [Option<Vec<u8>>],
) -> Result<Vec<Vec<u8>>, FecError>;
}
/// Construct the coder for a scheme.
pub fn coder_for(scheme: FecScheme) -> Box<dyn ErasureCoder> {
match scheme {
FecScheme::Gf8 => Box::new(Gf8Coder),
FecScheme::Gf16 => Box::new(Gf16Coder),
}
}
/// Validate the shape `reconstruct` promises: `received.len() == data + recovery`, and
/// every present shard shares one length. Both backends call this first so neither the
/// fast path nor a malformed caller can slip mismatched-length or wrong-count shards
/// through (the fast paths bypass the backend's own length checks otherwise).
pub(crate) fn validate_block_shape(
received: &[Option<Vec<u8>>],
data_count: usize,
recovery_count: usize,
) -> Result<(), FecError> {
if received.len() != data_count + recovery_count {
return Err(FecError::Config(
"received length must equal data + recovery",
));
}
let mut len = None;
for s in received.iter().flatten() {
match len {
None => len = Some(s.len()),
Some(l) if l != s.len() => {
return Err(FecError::Config("shards in a block must be equal length"));
}
_ => {}
}
}
Ok(())
}
/// Validate `encode` inputs: at least one data shard, all of equal length.
pub(crate) fn validate_encode_shape(data: &[Vec<u8>]) -> Result<(), FecError> {
let first = data
.first()
.ok_or(FecError::Config("no data shards"))?
.len();
if data.iter().any(|s| s.len() != first) {
return Err(FecError::Config("data shards must be equal length"));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
/// Round-trip a block through a coder, losing exactly `lose` shards (some data,
/// some recovery), and assert the originals come back byte-identical.
fn roundtrip(coder: &dyn ErasureCoder, k: usize, m: usize, shard_len: usize, lose: &[usize]) {
let data: Vec<Vec<u8>> = (0..k)
.map(|i| (0..shard_len).map(|b| (i * 31 + b * 7) as u8).collect())
.collect();
let recovery = coder.encode(&data, m).unwrap();
assert_eq!(recovery.len(), m);
let mut received: Vec<Option<Vec<u8>>> = Vec::with_capacity(k + m);
received.extend(data.iter().cloned().map(Some));
received.extend(recovery.iter().cloned().map(Some));
for &idx in lose {
received[idx] = None;
}
let restored = coder.reconstruct(k, m, &mut received).unwrap();
assert_eq!(restored, data);
}
#[test]
fn gf8_recovers_within_budget() {
// 16 data + 4 recovery; lose 2 data + 2 recovery (== budget).
roundtrip(&Gf8Coder, 16, 4, 256, &[0, 7, 16, 19]);
}
#[test]
fn gf16_recovers_within_budget() {
roundtrip(&Gf16Coder, 16, 4, 256, &[1, 9, 17, 18]);
}
#[test]
fn gf8_too_much_loss_errors() {
let data: Vec<Vec<u8>> = (0..8).map(|_| vec![0u8; 64]).collect();
let recovery = Gf8Coder.encode(&data, 2).unwrap();
let mut received: Vec<Option<Vec<u8>>> = data
.iter()
.cloned()
.map(Some)
.chain(recovery.into_iter().map(Some))
.collect();
// Lose 3 with only 2 recovery shards → unrecoverable.
received[0] = None;
received[1] = None;
received[2] = None;
assert!(Gf16Coder.scheme() == FecScheme::Gf16);
let err = Gf8Coder.reconstruct(8, 2, &mut received);
assert!(err.is_err());
}
#[test]
fn reconstruct_rejects_wrong_received_length() {
// data=2, recovery=2 expects a 4-element slice; a 3-element one must error, not
// panic on the recovery-slice index (both backends).
let mut recv: Vec<Option<Vec<u8>>> = vec![Some(vec![0u8; 8]), None, Some(vec![0u8; 8])];
assert!(Gf16Coder.reconstruct(2, 2, &mut recv).is_err());
let mut recv: Vec<Option<Vec<u8>>> = vec![Some(vec![0u8; 8]), None, Some(vec![0u8; 8])];
assert!(Gf8Coder.reconstruct(2, 2, &mut recv).is_err());
}
#[test]
fn reconstruct_rejects_mismatched_shard_lengths() {
// The GF16 fast path used to clone shards verbatim without a length check.
let mut recv: Vec<Option<Vec<u8>>> =
vec![Some(vec![0u8; 8]), Some(vec![0u8; 6]), None, None];
assert!(Gf16Coder.reconstruct(2, 2, &mut recv).is_err());
let mut recv: Vec<Option<Vec<u8>>> =
vec![Some(vec![0u8; 8]), Some(vec![0u8; 6]), None, None];
assert!(Gf8Coder.reconstruct(2, 2, &mut recv).is_err());
}
}
+151
View File
@@ -0,0 +1,151 @@
//! Input events flowing client → host (and the host-side receive callback).
//!
//! Input rides the same transport as video but on its own wire tag
//! ([`INPUT_MAGIC`]), so a session can demultiplex video from input by the first byte.
/// Wire tag distinguishing an input datagram from a video packet.
pub const INPUT_MAGIC: u8 = 0xC8;
/// Fixed serialized size of an [`InputEvent`] on the wire (tag + fields).
pub const INPUT_WIRE_LEN: usize = 1 + 1 + 4 + 4 + 4 + 4; // = 18
/// Kinds of input event. `#[repr(u8)]` so it crosses the C ABI as a byte tag.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InputKind {
KeyDown = 0,
KeyUp = 1,
/// Relative motion: `x`/`y` carry `dx`/`dy`.
MouseMove = 2,
/// Absolute motion: `x`/`y` carry pixel coordinates.
MouseMoveAbs = 3,
MouseButtonDown = 4,
MouseButtonUp = 5,
/// `x` carries the (signed) scroll delta.
MouseScroll = 6,
/// `code` = button bit ([`gamepad`] `BTN_*`), `x` ≠ 0 = pressed, `flags` = pad index.
GamepadButton = 7,
/// `code` = axis id ([`gamepad`] `AXIS_*`), `x` = axis value, `flags` = pad index.
/// Sticks are i16 range (32768..32767) in the XInput/Moonlight convention — **+y =
/// up** (unlike mouse coordinates); triggers 0..255.
GamepadAxis = 8,
}
/// The gamepad wire contract for [`InputKind::GamepadButton`]/[`InputKind::GamepadAxis`].
///
/// Everything follows the GameStream/XInput conventions end to end: buttons reuse
/// GameStream's `buttonFlags` bit positions, sticks are 32768..32767 with **+y = up**,
/// triggers 0..255 (what Moonlight sends and what the host's virtual xpad already
/// consumes). One event carries one transition: `code` = the bit below, `x` = 1 pressed /
/// 0 released. Axes are sent individually; the host accumulates per-pad state and emits
/// one evdev SYN per event.
pub mod gamepad {
pub const BTN_DPAD_UP: u32 = 0x0001;
pub const BTN_DPAD_DOWN: u32 = 0x0002;
pub const BTN_DPAD_LEFT: u32 = 0x0004;
pub const BTN_DPAD_RIGHT: u32 = 0x0008;
pub const BTN_START: u32 = 0x0010;
pub const BTN_BACK: u32 = 0x0020;
pub const BTN_LS_CLICK: u32 = 0x0040;
pub const BTN_RS_CLICK: u32 = 0x0080;
pub const BTN_LB: u32 = 0x0100;
pub const BTN_RB: u32 = 0x0200;
pub const BTN_GUIDE: u32 = 0x0400;
pub const BTN_A: u32 = 0x1000;
pub const BTN_B: u32 = 0x2000;
pub const BTN_X: u32 = 0x4000;
pub const BTN_Y: u32 = 0x8000;
/// Axis ids for `InputKind::GamepadAxis`.
pub const AXIS_LS_X: u32 = 0;
pub const AXIS_LS_Y: u32 = 1;
pub const AXIS_RS_X: u32 = 2;
pub const AXIS_RS_Y: u32 = 3;
/// Triggers: value range 0..255.
pub const AXIS_LT: u32 = 4;
pub const AXIS_RT: u32 = 5;
}
impl InputKind {
pub fn from_u8(v: u8) -> Option<InputKind> {
use InputKind::*;
Some(match v {
0 => KeyDown,
1 => KeyUp,
2 => MouseMove,
3 => MouseMoveAbs,
4 => MouseButtonDown,
5 => MouseButtonUp,
6 => MouseScroll,
7 => GamepadButton,
8 => GamepadAxis,
_ => return None,
})
}
}
/// A single input event. `#[repr(C)]` — shared verbatim with the C ABI as
/// `PunktfunkInputEvent`.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct InputEvent {
pub kind: InputKind,
pub _pad: [u8; 3],
/// keycode / button id / axis id, depending on `kind`.
pub code: u32,
/// x / dx / abs-x / axis-value / scroll-delta, depending on `kind`.
pub x: i32,
/// y / dy / abs-y, depending on `kind`.
pub y: i32,
/// modifier bitmask or gamepad index.
pub flags: u32,
}
impl InputEvent {
/// Serialize to the fixed wire layout (`INPUT_MAGIC` + little-endian fields).
pub fn encode(&self) -> [u8; INPUT_WIRE_LEN] {
let mut b = [0u8; INPUT_WIRE_LEN];
b[0] = INPUT_MAGIC;
b[1] = self.kind as u8;
b[2..6].copy_from_slice(&self.code.to_le_bytes());
b[6..10].copy_from_slice(&self.x.to_le_bytes());
b[10..14].copy_from_slice(&self.y.to_le_bytes());
b[14..18].copy_from_slice(&self.flags.to_le_bytes());
b
}
/// Parse from the wire layout. Returns `None` on bad tag/length/kind.
pub fn decode(buf: &[u8]) -> Option<InputEvent> {
if buf.len() < INPUT_WIRE_LEN || buf[0] != INPUT_MAGIC {
return None;
}
let kind = InputKind::from_u8(buf[1])?;
Some(InputEvent {
kind,
_pad: [0; 3],
code: u32::from_le_bytes(buf[2..6].try_into().unwrap()),
x: i32::from_le_bytes(buf[6..10].try_into().unwrap()),
y: i32::from_le_bytes(buf[10..14].try_into().unwrap()),
flags: u32::from_le_bytes(buf[14..18].try_into().unwrap()),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn input_wire_roundtrip() {
let e = InputEvent {
kind: InputKind::MouseMove,
_pad: [0; 3],
code: 0,
x: -12,
y: 34,
flags: 0xABCD,
};
assert_eq!(InputEvent::decode(&e.encode()), Some(e));
assert!(InputEvent::decode(&[0u8; INPUT_WIRE_LEN]).is_none()); // bad magic
}
}
+49
View File
@@ -0,0 +1,49 @@
//! # punktfunk-core
//!
//! The shared protocol / transport / FEC core for the punktfunk low-latency streaming
//! stack. It is compiled exactly once and linked by every host and client — directly
//! as a Rust `lib`, or across the [C ABI](crate::abi) by Swift / Kotlin / C clients.
//!
//! Everything platform-specific (capture, encode, decode, present, input injection)
//! lives *outside* this crate. What lives *here*:
//!
//! - [`fec`] — erasure coding. GF(2⁸) for GameStream/Moonlight compatibility (P1) and
//! GF(2¹⁶) Leopard-RS (P2) which removes the ~1 Gbps per-frame shard-count ceiling.
//! - [`packet`] — `#[repr(C)]` zero-copy wire framing: splitting an access unit into
//! FEC blocks of MTU-sized shards and reassembling them on the far side.
//! - [`crypto`] — AES-128-GCM session sealing, matching GameStream in P1.
//! - [`session`] — the host (submit frame → FEC → packetize → seal → send) and client
//! (recv → open → reorder → FEC recover → reassemble) state machines.
//! - [`transport`] — pluggable packet I/O (in-process loopback for tests; UDP for real).
//! - [`abi`] — the `extern "C"` surface and `cbindgen`-generated `punktfunk_core.h`.
//!
//! ## Threading contract
//!
//! Nothing in the per-frame path touches an async runtime. `tokio`/`quinn` are gated
//! behind the off-by-default `quic` feature and used only for the control plane.
#![forbid(unsafe_op_in_unsafe_fn)]
pub mod abi;
#[cfg(feature = "quic")]
pub mod client;
pub mod config;
pub mod crypto;
pub mod error;
pub mod fec;
pub mod input;
pub mod packet;
#[cfg(feature = "quic")]
pub mod quic;
pub mod session;
pub mod stats;
pub mod transport;
pub use config::{Config, FecConfig, FecScheme, Mode, ProtocolPhase, Role};
pub use error::{PunktfunkError, PunktfunkStatus, Result};
pub use session::{Frame, Session};
pub use stats::Stats;
/// Bump on any breaking change to the [C ABI](crate::abi). Mirrors
/// `punktfunk_abi_version()` and is checked by clients before use.
pub const ABI_VERSION: u32 = 1;
+581
View File
@@ -0,0 +1,581 @@
//! Zero-copy wire framing: split an access unit into FEC blocks of MTU-sized shards,
//! and reassemble + FEC-recover them on the far side.
//!
//! ## Wire layout
//!
//! Each packet is a fixed [`PacketHeader`] followed by one FEC shard's payload. Fields
//! are host-endian for now (every target platform is little-endian); the `punktfunk/1` (P2)
//! spec will pin byte order explicitly when we talk to non-LE peers.
//!
//! ## GameStream mapping (P1)
//!
//! `frame_index`↔`frameIndex`, `stream_seq`↔`streamPacketIndex`,
//! (`block_index`,`block_count`)↔the `multiFecBlocks` nibbles, and
//! (`data_shards`,`recovery_shards`,`shard_index`)↔the `fecInfo` bitfield. We carry them
//! as explicit fields rather than bit-packing; full GameStream wire-exactness is an M2
//! concern (it also needs RTP framing + RTSP), this is the coherent internal format.
use crate::config::Config;
use crate::error::{PunktfunkError, Result};
use crate::fec::ErasureCoder;
use crate::session::Frame;
use crate::stats::StatsCounters;
use std::collections::{BTreeMap, HashMap, HashSet};
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
/// Identifies a punktfunk video packet (vs. an input datagram, see [`crate::input`]).
pub const PUNKTFUNK_MAGIC: u8 = 0xC9;
// Frame flags (mirroring GameStream's FLAG_*).
pub const FLAG_PIC: u8 = 0x1;
pub const FLAG_EOF: u8 = 0x2;
pub const FLAG_SOF: u8 = 0x4;
/// Crypto framing overhead [`Session`](crate::session::Session) adds when encrypting:
/// an 8-byte sequence prefix plus the GCM tag.
pub const CRYPTO_OVERHEAD: usize = 8 + crate::crypto::TAG_LEN;
/// Largest UDP datagram the core will send or accept. `Config::validate` bounds
/// `shard_payload` so `HEADER_LEN + shard_payload + CRYPTO_OVERHEAD ≤ MAX_DATAGRAM_BYTES`.
pub const MAX_DATAGRAM_BYTES: usize = 2048;
/// How many frames behind the newest the reassembler keeps before pruning stragglers.
const REORDER_WINDOW: u32 = 16;
/// Fixed per-packet header. `#[repr(C)]`, no padding, zero-copy (de)serializable.
#[repr(C)]
#[derive(Clone, Copy, Debug, FromBytes, IntoBytes, KnownLayout, Immutable)]
pub struct PacketHeader {
pub pts_ns: u64,
pub frame_index: u32,
pub stream_seq: u32,
pub frame_bytes: u32,
pub user_flags: u32,
pub block_index: u16,
pub block_count: u16,
pub data_shards: u16,
pub recovery_shards: u16,
pub shard_index: u16,
pub shard_bytes: u16,
pub magic: u8,
pub version: u8,
pub fec_scheme: u8,
pub flags: u8,
}
/// Size of [`PacketHeader`] on the wire (40 bytes).
pub const HEADER_LEN: usize = std::mem::size_of::<PacketHeader>();
const _: () = assert!(HEADER_LEN == 40, "PacketHeader must be 40 bytes / unpadded");
// ---------------------------------------------------------------------------
// Host side: packetization
// ---------------------------------------------------------------------------
/// Splits encoded access units into FEC-protected shard packets. Host-side only.
pub struct Packetizer {
next_frame_index: u32,
next_seq: u32,
shard_payload: usize,
fec: crate::config::FecConfig,
version: u8,
}
impl Packetizer {
pub fn new(config: &Config) -> Self {
Packetizer {
next_frame_index: 0,
next_seq: 0,
shard_payload: config.shard_payload,
fec: config.fec,
version: config.phase as u8,
}
}
/// Packetize one access unit into wire packets (header + shard payload each).
pub fn packetize(
&mut self,
frame: &[u8],
pts_ns: u64,
user_flags: u32,
coder: &dyn ErasureCoder,
) -> Result<Vec<Vec<u8>>> {
let payload = self.shard_payload;
let frame_index = self.next_frame_index;
self.next_frame_index = self.next_frame_index.wrapping_add(1);
// At least one (zero-padded) data shard even for an empty frame.
let total_data = frame.len().div_ceil(payload).max(1);
let max_block = self.fec.max_data_per_block as usize;
let block_count = total_data.div_ceil(max_block).max(1);
let frame_bytes = frame.len() as u32;
// Defend the u16 wire fields against silent truncation. `Config::validate`
// already rejects configs that could reach these for valid frame sizes; this is
// the belt-and-suspenders for a frame larger than the negotiated maximum.
if payload > u16::MAX as usize {
return Err(PunktfunkError::InvalidArg("shard_payload exceeds u16"));
}
if block_count > u16::MAX as usize {
return Err(PunktfunkError::Unsupported(
"frame too large: block count exceeds u16",
));
}
let mut packets = Vec::new();
for b in 0..block_count {
let first = b * max_block;
let last = ((b + 1) * max_block).min(total_data);
let block_data_count = last - first;
// Build this block's data shards (each `payload` bytes, last zero-padded).
let mut data_shards: Vec<Vec<u8>> = Vec::with_capacity(block_data_count);
for s in first..last {
let start = s * payload;
let end = (start + payload).min(frame.len());
let mut shard = vec![0u8; payload];
if start < frame.len() {
shard[..end - start].copy_from_slice(&frame[start..end]);
}
data_shards.push(shard);
}
let recovery_count = self.fec.recovery_for(block_data_count);
let recovery = coder.encode(&data_shards, recovery_count)?;
let total_shards = block_data_count + recovery_count;
if total_shards > u16::MAX as usize {
return Err(PunktfunkError::Unsupported("block shard count exceeds u16"));
}
for shard_index in 0..total_shards {
let body: &[u8] = if shard_index < block_data_count {
&data_shards[shard_index]
} else {
&recovery[shard_index - block_data_count]
};
let seq = self.next_seq;
self.next_seq = self.next_seq.wrapping_add(1);
let mut flags = FLAG_PIC;
if b == 0 && shard_index == 0 {
flags |= FLAG_SOF;
}
if b + 1 == block_count && shard_index + 1 == total_shards {
flags |= FLAG_EOF;
}
let hdr = PacketHeader {
pts_ns,
frame_index,
stream_seq: seq,
frame_bytes,
user_flags,
block_index: b as u16,
block_count: block_count as u16,
data_shards: block_data_count as u16,
recovery_shards: recovery_count as u16,
shard_index: shard_index as u16,
shard_bytes: payload as u16,
magic: PUNKTFUNK_MAGIC,
version: self.version,
fec_scheme: coder.scheme() as u8,
flags,
};
let mut pkt = Vec::with_capacity(HEADER_LEN + body.len());
pkt.extend_from_slice(hdr.as_bytes());
pkt.extend_from_slice(body);
packets.push(pkt);
}
}
Ok(packets)
}
}
// ---------------------------------------------------------------------------
// Client side: reassembly + FEC recovery
// ---------------------------------------------------------------------------
struct BlockBuf {
data_shards: usize,
recovery_shards: usize,
shard_bytes: usize,
/// Length `data_shards + recovery_shards`; `Some` = received.
shards: Vec<Option<Vec<u8>>>,
received: usize,
done: bool,
}
struct FrameBuf {
frame_bytes: usize,
block_count: usize,
pts_ns: u64,
user_flags: u32,
blocks: HashMap<u16, BlockBuf>,
/// Reconstructed payload per completed block, ordered by block index.
block_data: BTreeMap<u16, Vec<u8>>,
}
/// Per-session bounds the reassembler enforces on every packet header *before*
/// allocating, so a hostile or corrupt header cannot drive unbounded memory use. All
/// derived from the negotiated [`Config`].
#[derive(Clone, Copy, Debug)]
pub struct ReassemblerLimits {
/// Expected shard payload length; every shard in the stream must match exactly.
pub shard_bytes: usize,
/// Max data shards per block (the negotiated `max_data_per_block`).
pub max_data_shards: usize,
/// Max total shards per block (data + recovery), capped by the FEC scheme ceiling.
pub max_total_shards: usize,
/// Max FEC blocks per frame.
pub max_blocks: usize,
/// Max accepted access-unit size.
pub max_frame_bytes: usize,
}
impl ReassemblerLimits {
pub fn from_config(c: &Config) -> Self {
let max_data = c.fec.max_data_per_block as usize;
let max_total =
(max_data + c.fec.recovery_for(max_data)).min(c.fec.scheme.max_total_shards());
let total_data = c.max_frame_bytes.div_ceil(c.shard_payload.max(1)).max(1);
ReassemblerLimits {
shard_bytes: c.shard_payload,
max_data_shards: max_data,
max_total_shards: max_total,
max_blocks: total_data.div_ceil(max_data).max(1),
max_frame_bytes: c.max_frame_bytes,
}
}
}
/// Buffers incoming shards, recovers lost ones via FEC, and emits whole access units.
/// Client-side only.
pub struct Reassembler {
limits: ReassemblerLimits,
frames: HashMap<u32, FrameBuf>,
/// Recently-emitted frames, so stray/late shards can't resurrect them. Pruned to
/// the reorder window alongside `frames`.
completed: HashSet<u32>,
newest_frame: Option<u32>,
}
impl Reassembler {
pub fn new(limits: ReassemblerLimits) -> Self {
Reassembler {
limits,
frames: HashMap::new(),
completed: HashSet::new(),
newest_frame: None,
}
}
/// Ingest one (already-decrypted) packet. Returns the access unit when its last
/// block completes, otherwise `None`.
pub fn push(
&mut self,
pkt: &[u8],
coder: &dyn ErasureCoder,
stats: &StatsCounters,
) -> Result<Option<Frame>> {
// On a lossy datagram link a malformed or non-video packet is dropped, never
// fatal: it must not abort `poll_frame`. Only a genuine FEC reconstruction
// failure propagates as an error.
if pkt.len() < HEADER_LEN {
StatsCounters::add(&stats.packets_dropped, 1);
return Ok(None);
}
let hdr = match PacketHeader::read_from_bytes(&pkt[..HEADER_LEN]) {
Ok(h) => h,
Err(_) => {
StatsCounters::add(&stats.packets_dropped, 1);
return Ok(None);
}
};
let lim = self.limits;
let shard_bytes = hdr.shard_bytes as usize;
let data_shards = hdr.data_shards as usize;
let recovery_shards = hdr.recovery_shards as usize;
let total = data_shards + recovery_shards;
let shard_index = hdr.shard_index as usize;
let block_count = hdr.block_count as usize;
let frame_bytes = hdr.frame_bytes as usize;
// Bound every attacker-controllable header field against the negotiated limits
// BEFORE allocating anything keyed on it — this is the firewall against a tiny
// datagram triggering a huge `vec![None; total]` / `Vec::with_capacity`.
let drop = |stats: &StatsCounters| {
StatsCounters::add(&stats.packets_dropped, 1);
};
if hdr.magic != PUNKTFUNK_MAGIC
|| shard_bytes != lim.shard_bytes
|| pkt.len() < HEADER_LEN + shard_bytes
|| data_shards == 0
|| data_shards > lim.max_data_shards
|| total == 0
|| total > lim.max_total_shards
|| shard_index >= total
|| block_count == 0
|| block_count > lim.max_blocks
|| hdr.block_index as usize >= block_count
|| frame_bytes > lim.max_frame_bytes
{
drop(stats);
return Ok(None);
}
let payload = pkt[HEADER_LEN..HEADER_LEN + shard_bytes].to_vec();
self.advance_window(hdr.frame_index, stats);
// Drop shards for frames we've already emitted (e.g. the recovery shards of a
// frame that completed early via the all-originals-present fast path) or that
// have fallen out of the reorder window.
if self.completed.contains(&hdr.frame_index) || self.is_stale(hdr.frame_index) {
drop(stats);
return Ok(None);
}
// First packet of a frame establishes its geometry; later packets must agree.
let frame = self
.frames
.entry(hdr.frame_index)
.or_insert_with(|| FrameBuf {
frame_bytes,
block_count,
pts_ns: hdr.pts_ns,
user_flags: hdr.user_flags,
blocks: HashMap::new(),
block_data: BTreeMap::new(),
});
if frame.block_count != block_count || frame.frame_bytes != frame_bytes {
drop(stats);
return Ok(None);
}
if frame.block_data.contains_key(&hdr.block_index) {
return Ok(None); // block already reconstructed; late/duplicate shard
}
// First packet of a block sizes its shard vector; later packets must match its
// (data, recovery, shard_bytes) geometry, so `shard_index` is always in bounds.
frame
.blocks
.entry(hdr.block_index)
.or_insert_with(|| BlockBuf {
data_shards,
recovery_shards,
shard_bytes,
shards: vec![None; total],
received: 0,
done: false,
});
let block = frame.blocks.get_mut(&hdr.block_index).unwrap();
if block.data_shards != data_shards
|| block.recovery_shards != recovery_shards
|| block.shard_bytes != shard_bytes
{
drop(stats);
return Ok(None);
}
if block.shards[shard_index].is_none() {
block.shards[shard_index] = Some(payload);
block.received += 1;
}
// Reconstruct as soon as we hold enough shards.
if !block.done && block.received >= block.data_shards {
let present_data = block.shards[..block.data_shards]
.iter()
.filter(|s| s.is_some())
.count();
let recovered =
coder.reconstruct(block.data_shards, block.recovery_shards, &mut block.shards)?;
block.done = true;
StatsCounters::add(
&stats.fec_recovered_shards,
(block.data_shards - present_data) as u64,
);
// Concatenate the block's data shards into its contiguous payload.
let mut block_payload = Vec::with_capacity(block.data_shards * block.shard_bytes);
for shard in &recovered {
block_payload.extend_from_slice(shard);
}
frame.block_data.insert(hdr.block_index, block_payload);
frame.blocks.remove(&hdr.block_index);
}
// Whole frame ready?
if frame.block_data.len() == frame.block_count {
let frame = self.frames.remove(&hdr.frame_index).unwrap();
self.completed.insert(hdr.frame_index);
// Reserve based on the bytes we actually hold, not the (already-bounded but
// still caller-supplied) frame_bytes, so a small frame can't over-reserve.
let actual: usize = frame.block_data.values().map(|b| b.len()).sum();
let mut data = Vec::with_capacity(actual);
for (_, block_payload) in frame.block_data.into_iter() {
data.extend_from_slice(&block_payload);
}
data.truncate(frame.frame_bytes); // trim trailing-shard zero padding
return Ok(Some(Frame {
data,
frame_index: hdr.frame_index,
pts_ns: frame.pts_ns,
flags: frame.user_flags,
}));
}
Ok(None)
}
/// Track the newest frame and prune stragglers that fell out of the reorder window
/// (counting them as dropped).
fn advance_window(&mut self, frame_index: u32, stats: &StatsCounters) {
let newest = match self.newest_frame {
// `frame_index` is newer iff it's within the forward half of the index space.
Some(n) if frame_index.wrapping_sub(n) > u32::MAX / 2 => n,
_ => frame_index,
};
self.newest_frame = Some(newest);
let before = self.frames.len();
self.frames
.retain(|&idx, _| newest.wrapping_sub(idx) <= REORDER_WINDOW);
let pruned = before - self.frames.len();
if pruned > 0 {
StatsCounters::add(&stats.frames_dropped, pruned as u64);
}
self.completed
.retain(|&idx| newest.wrapping_sub(idx) <= REORDER_WINDOW);
}
/// True if `frame_index` lies behind the newest frame by more than the reorder
/// window (so its shards arrive too late to be useful).
fn is_stale(&self, frame_index: u32) -> bool {
match self.newest_frame {
Some(n) => {
let behind = n.wrapping_sub(frame_index);
behind > REORDER_WINDOW && behind <= u32::MAX / 2
}
None => false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::FecScheme;
use crate::fec::coder_for;
fn limits() -> ReassemblerLimits {
ReassemblerLimits {
shard_bytes: 16,
max_data_shards: 8,
max_total_shards: 12,
max_blocks: 4,
max_frame_bytes: 4096,
}
}
fn base_header() -> PacketHeader {
PacketHeader {
pts_ns: 0,
frame_index: 0,
stream_seq: 0,
frame_bytes: 16,
user_flags: 0,
block_index: 0,
block_count: 1,
data_shards: 1,
recovery_shards: 0,
shard_index: 0,
shard_bytes: 16,
magic: PUNKTFUNK_MAGIC,
version: 1,
fec_scheme: 0,
flags: FLAG_PIC,
}
}
fn packet(h: PacketHeader) -> Vec<u8> {
let mut p = Vec::new();
p.extend_from_slice(h.as_bytes());
p.extend_from_slice(&vec![0xAB; h.shard_bytes as usize]);
p
}
/// A header advertising 65535+65535 shards must be dropped, not allocate gigabytes.
#[test]
fn rejects_oversized_shard_counts() {
let mut r = Reassembler::new(limits());
let coder = coder_for(FecScheme::Gf8);
let stats = StatsCounters::default();
let mut h = base_header();
h.data_shards = 65535;
h.recovery_shards = 65535;
assert!(r
.push(&packet(h), coder.as_ref(), &stats)
.unwrap()
.is_none());
assert_eq!(stats.snapshot().packets_dropped, 1);
}
/// A second packet for a block whose geometry differs from the first must be dropped
/// — never index past the block's allocated shard vector (the old OOB panic).
#[test]
fn rejects_inconsistent_block_geometry_without_panicking() {
let mut r = Reassembler::new(limits());
let coder = coder_for(FecScheme::Gf8);
let stats = StatsCounters::default();
let mut h1 = base_header();
h1.data_shards = 4;
h1.recovery_shards = 2; // block sized to 6 slots
h1.frame_bytes = 64;
assert!(r
.push(&packet(h1), coder.as_ref(), &stats)
.unwrap()
.is_none());
// Same block, different geometry, shard_index valid for ITS total (8) but past
// the established block's 6 slots.
let mut h2 = base_header();
h2.data_shards = 6;
h2.recovery_shards = 2;
h2.shard_index = 7;
h2.frame_bytes = 64;
assert!(r
.push(&packet(h2), coder.as_ref(), &stats)
.unwrap()
.is_none());
assert_eq!(stats.snapshot().packets_dropped, 1);
}
#[test]
fn rejects_wrong_shard_bytes_and_oversized_frame() {
let coder = coder_for(FecScheme::Gf8);
let mut r = Reassembler::new(limits());
let stats = StatsCounters::default();
let mut h = base_header();
h.shard_bytes = 8; // != negotiated 16
assert!(r
.push(&packet(h), coder.as_ref(), &stats)
.unwrap()
.is_none());
assert_eq!(stats.snapshot().packets_dropped, 1);
let mut r = Reassembler::new(limits());
let stats = StatsCounters::default();
let mut h = base_header();
h.frame_bytes = 1_000_000; // > max_frame_bytes
assert!(r
.push(&packet(h), coder.as_ref(), &stats)
.unwrap()
.is_none());
assert_eq!(stats.snapshot().packets_dropped, 1);
}
}
+527
View File
@@ -0,0 +1,527 @@
//! `punktfunk/1` — the native control plane (M3), gated behind the `quic` feature.
//!
//! GameStream is punktfunk's compatibility layer; this is the start of its own protocol. A QUIC
//! connection (quinn, tokio — control plane only, never the per-frame path) carries a
//! length-prefixed binary handshake on one bidirectional stream:
//!
//! ```text
//! client → host Hello { abi_version }
//! host → client Welcome { abi_version, session: full data-plane Config + mode + UDP port }
//! client → host Start { client_udp_port }
//! ```
//!
//! after which both sides bring up a [`crate::session::Session`] over a plain
//! [`UdpTransport`](crate::transport::udp) (native threads, no async) and the host streams.
//! The Welcome carries everything the M1 core negotiates — FEC scheme (including GF(2¹⁶)
//! Leopard, which GameStream can't express), shard sizing, crypto key/salt — so the data
//! plane is exactly the hardened M1 `Session`.
//!
//! Transport security: the host presents a long-lived self-signed certificate
//! ([`endpoint::server_with_identity`]) and the client pins its SHA-256 fingerprint
//! ([`endpoint::client_pinned`]; no pin = trust-on-first-use, with the observed fingerprint
//! reported back for persisting). The data plane adds AES-GCM on top.
//! All integers little-endian; every message is `u16 length || payload`.
use crate::config::{Config, FecConfig, FecScheme, Mode, ProtocolPhase, Role};
use crate::error::{PunktfunkError, Result};
/// Protocol magic + version, first bytes of every message payload.
pub const MAGIC: &[u8; 4] = b"PKF1";
/// `client → host`: open the session, requesting a display mode (the host creates its
/// virtual output at exactly this size/refresh — native resolution end to end).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Hello {
pub abi_version: u32,
pub mode: Mode,
}
/// `host → client`: the complete session offer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Welcome {
pub abi_version: u32,
/// Host UDP port for the data plane.
pub udp_port: u16,
pub mode: Mode,
pub fec: FecConfig,
pub shard_payload: u16,
pub encrypt: bool,
pub key: [u8; 16],
pub salt: [u8; 4],
/// Seed/testing: how many frames the host will send (0 = unbounded).
pub frames: u32,
}
/// `client → host`: data plane is bound, begin streaming.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Start {
pub client_udp_port: u16,
}
impl Hello {
pub fn encode(&self) -> Vec<u8> {
let mut b = Vec::with_capacity(20);
b.extend_from_slice(MAGIC);
b.extend_from_slice(&self.abi_version.to_le_bytes());
b.extend_from_slice(&self.mode.width.to_le_bytes());
b.extend_from_slice(&self.mode.height.to_le_bytes());
b.extend_from_slice(&self.mode.refresh_hz.to_le_bytes());
b
}
pub fn decode(b: &[u8]) -> Result<Hello> {
if b.len() < 20 || &b[0..4] != MAGIC {
return Err(PunktfunkError::InvalidArg("bad Hello"));
}
let u32at = |o: usize| u32::from_le_bytes([b[o], b[o + 1], b[o + 2], b[o + 3]]);
Ok(Hello {
abi_version: u32at(4),
mode: Mode {
width: u32at(8),
height: u32at(12),
refresh_hz: u32at(16),
},
})
}
}
impl Welcome {
pub fn encode(&self) -> Vec<u8> {
let mut b = Vec::with_capacity(64);
b.extend_from_slice(MAGIC);
b.extend_from_slice(&self.abi_version.to_le_bytes());
b.extend_from_slice(&self.udp_port.to_le_bytes());
b.extend_from_slice(&self.mode.width.to_le_bytes());
b.extend_from_slice(&self.mode.height.to_le_bytes());
b.extend_from_slice(&self.mode.refresh_hz.to_le_bytes());
b.push(match self.fec.scheme {
FecScheme::Gf8 => 0,
FecScheme::Gf16 => 1,
});
b.push(self.fec.fec_percent);
b.extend_from_slice(&self.fec.max_data_per_block.to_le_bytes());
b.extend_from_slice(&self.shard_payload.to_le_bytes());
b.push(self.encrypt as u8);
b.extend_from_slice(&self.key);
b.extend_from_slice(&self.salt);
b.extend_from_slice(&self.frames.to_le_bytes());
b
}
pub fn decode(b: &[u8]) -> Result<Welcome> {
// Layout (LE): magic[0..4] abi[4..8] port[8..10] w[10..14] h[14..18] hz[18..22]
// scheme[22] pct[23] max_data[24..26] shard[26..28] encrypt[28] key[29..45]
// salt[45..49] frames[49..53].
if b.len() < 53 || &b[0..4] != MAGIC {
return Err(PunktfunkError::InvalidArg("bad Welcome"));
}
let u32at = |o: usize| u32::from_le_bytes([b[o], b[o + 1], b[o + 2], b[o + 3]]);
let u16at = |o: usize| u16::from_le_bytes([b[o], b[o + 1]]);
let mut key = [0u8; 16];
key.copy_from_slice(&b[29..45]);
let mut salt = [0u8; 4];
salt.copy_from_slice(&b[45..49]);
Ok(Welcome {
abi_version: u32at(4),
udp_port: u16at(8),
mode: Mode {
width: u32at(10),
height: u32at(14),
refresh_hz: u32at(18),
},
fec: FecConfig {
scheme: if b[22] == 1 {
FecScheme::Gf16
} else {
FecScheme::Gf8
},
fec_percent: b[23],
max_data_per_block: u16at(24),
},
shard_payload: u16at(26),
encrypt: b[28] != 0,
key,
salt,
frames: u32at(49),
})
}
/// Build the data-plane [`Config`] this offer describes (for `role`).
pub fn session_config(&self, role: Role) -> Config {
let mut c = Config::p1_defaults(role);
c.phase = ProtocolPhase::P1GameStream; // wire phase id pending the P2 packet rev
c.fec = self.fec;
c.shard_payload = self.shard_payload as usize;
c.encrypt = self.encrypt;
c.key = self.key;
c.salt = self.salt;
c
}
}
impl Start {
pub fn encode(&self) -> Vec<u8> {
let mut b = Vec::with_capacity(6);
b.extend_from_slice(MAGIC);
b.extend_from_slice(&self.client_udp_port.to_le_bytes());
b
}
pub fn decode(b: &[u8]) -> Result<Start> {
if b.len() < 6 || &b[0..4] != MAGIC {
return Err(PunktfunkError::InvalidArg("bad Start"));
}
Ok(Start {
client_udp_port: u16::from_le_bytes([b[4], b[5]]),
})
}
}
/// Frame a message for the control stream: `u16 LE length || payload`.
pub fn frame(payload: &[u8]) -> Vec<u8> {
let mut b = Vec::with_capacity(2 + payload.len());
b.extend_from_slice(&(payload.len() as u16).to_le_bytes());
b.extend_from_slice(payload);
b
}
/// Datagram wire tags. Video rides UDP; everything low-rate rides QUIC datagrams,
/// demultiplexed by the first byte: input = [`crate::input::INPUT_MAGIC`] (0xC8),
/// audio = [`AUDIO_MAGIC`], rumble = [`RUMBLE_MAGIC`].
pub const AUDIO_MAGIC: u8 = 0xC9;
pub const RUMBLE_MAGIC: u8 = 0xCA;
/// Audio datagram, host → client: `[0xC9][u32 seq LE][u64 pts_ns LE][opus payload]`.
/// One Opus frame per datagram (5 ms — well under any MTU); QUIC already encrypts.
pub fn encode_audio_datagram(seq: u32, pts_ns: u64, opus: &[u8]) -> Vec<u8> {
let mut b = Vec::with_capacity(13 + opus.len());
b.push(AUDIO_MAGIC);
b.extend_from_slice(&seq.to_le_bytes());
b.extend_from_slice(&pts_ns.to_le_bytes());
b.extend_from_slice(opus);
b
}
/// Parse an audio datagram → `(seq, pts_ns, opus payload)`. `None` on bad tag/length.
pub fn decode_audio_datagram(b: &[u8]) -> Option<(u32, u64, &[u8])> {
if b.len() < 13 || b[0] != AUDIO_MAGIC {
return None;
}
let seq = u32::from_le_bytes(b[1..5].try_into().unwrap());
let pts_ns = u64::from_le_bytes(b[5..13].try_into().unwrap());
Some((seq, pts_ns, &b[13..]))
}
/// Rumble datagram, host → client: `[0xCA][u16 pad LE][u16 low LE][u16 high LE]`.
/// Force-feedback state for pad `pad` (0xFFFF amplitudes, 0/0 = stop).
pub fn encode_rumble_datagram(pad: u16, low: u16, high: u16) -> [u8; 7] {
let mut b = [0u8; 7];
b[0] = RUMBLE_MAGIC;
b[1..3].copy_from_slice(&pad.to_le_bytes());
b[3..5].copy_from_slice(&low.to_le_bytes());
b[5..7].copy_from_slice(&high.to_le_bytes());
b
}
/// Parse a rumble datagram → `(pad, low, high)`. `None` on bad tag/length.
pub fn decode_rumble_datagram(b: &[u8]) -> Option<(u16, u16, u16)> {
if b.len() < 7 || b[0] != RUMBLE_MAGIC {
return None;
}
let u16at = |o: usize| u16::from_le_bytes([b[o], b[o + 1]]);
Some((u16at(1), u16at(3), u16at(5)))
}
/// Async framed-message IO over a quinn stream (`u16 LE length || payload`).
pub mod io {
/// Read one framed message (bounded at 64 KiB — control messages are tiny).
pub async fn read_msg(recv: &mut quinn::RecvStream) -> std::io::Result<Vec<u8>> {
let mut len = [0u8; 2];
recv.read_exact(&mut len)
.await
.map_err(std::io::Error::other)?;
let n = u16::from_le_bytes(len) as usize;
let mut buf = vec![0u8; n];
recv.read_exact(&mut buf)
.await
.map_err(std::io::Error::other)?;
Ok(buf)
}
/// Write one framed message.
pub async fn write_msg(send: &mut quinn::SendStream, payload: &[u8]) -> std::io::Result<()> {
send.write_all(&super::frame(payload))
.await
.map_err(std::io::Error::other)
}
}
/// quinn endpoint constructors. Host: self-signed identity (fresh, or persisted PEMs via
/// [`endpoint::server_with_identity`]). Client: fingerprint pinning / TOFU via
/// [`endpoint::client_pinned`] ([`endpoint::client_insecure`] is the no-pin special case).
pub mod endpoint {
use std::sync::{Arc, Mutex};
/// Server endpoint with a fresh self-signed certificate (tests/dev — production hosts
/// persist an identity and use [`server_with_identity`] so clients can pin it).
pub fn server(addr: std::net::SocketAddr) -> anyhow_result::Result<quinn::Endpoint> {
let cert = rcgen::generate_simple_self_signed(vec!["punktfunk".into()])
.map_err(|e| anyhow_result::Error::msg(format!("self-signed cert: {e}")))?;
let cert_der = rustls::pki_types::CertificateDer::from(cert.cert);
let key_der = rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
server_from_der(cert_der, key_der.into(), addr)
}
/// Server endpoint from a persisted PEM identity (certificate + PKCS#8 private key) —
/// the host's long-lived self-signed cert, so the fingerprint clients pin is stable
/// across restarts.
pub fn server_with_identity(
addr: std::net::SocketAddr,
cert_pem: &str,
key_pem: &str,
) -> anyhow_result::Result<quinn::Endpoint> {
use rustls::pki_types::pem::PemObject;
let cert_der = rustls::pki_types::CertificateDer::from_pem_slice(cert_pem.as_bytes())
.map_err(|e| anyhow_result::Error::msg(format!("cert pem: {e}")))?;
let key_der = rustls::pki_types::PrivateKeyDer::from_pem_slice(key_pem.as_bytes())
.map_err(|e| anyhow_result::Error::msg(format!("key pem: {e}")))?;
server_from_der(cert_der, key_der, addr)
}
fn server_from_der(
cert_der: rustls::pki_types::CertificateDer<'static>,
key_der: rustls::pki_types::PrivateKeyDer<'static>,
addr: std::net::SocketAddr,
) -> anyhow_result::Result<quinn::Endpoint> {
let server_config = quinn::ServerConfig::with_single_cert(vec![cert_der], key_der)
.map_err(|e| anyhow_result::Error::msg(format!("server config: {e}")))?;
Ok(quinn::Endpoint::server(server_config, addr)?)
}
/// SHA-256 of a certificate's DER encoding — the fingerprint clients pin.
pub fn cert_fingerprint(cert_der: &[u8]) -> [u8; 32] {
use sha2::Digest;
sha2::Sha256::digest(cert_der).into()
}
/// Fingerprint of a PEM-encoded certificate (what a host logs/shows for pairing UX —
/// must match what the client's verifier computes from the DER on the wire).
pub fn fingerprint_of_pem(cert_pem: &str) -> anyhow_result::Result<[u8; 32]> {
use rustls::pki_types::pem::PemObject;
let der = rustls::pki_types::CertificateDer::from_pem_slice(cert_pem.as_bytes())
.map_err(|e| anyhow_result::Error::msg(format!("cert pem: {e}")))?;
Ok(cert_fingerprint(der.as_ref()))
}
/// Client endpoint that skips certificate verification (TOFU bootstrap — read the
/// observed fingerprint off the slot and pin it on the next connect).
pub fn client_insecure() -> anyhow_result::Result<quinn::Endpoint> {
client_pinned(None).0
}
/// What [`client_pinned`] returns: the endpoint plus the slot the verifier writes the
/// observed host fingerprint into during the handshake.
pub type PinnedClient = (
anyhow_result::Result<quinn::Endpoint>,
Arc<Mutex<Option<[u8; 32]>>>,
);
/// Client endpoint that verifies the host by certificate fingerprint.
///
/// `pin = Some(sha256)` rejects any host whose leaf cert doesn't hash to `sha256`;
/// `None` accepts any (trust-on-first-use). Either way the observed fingerprint is
/// written to the returned slot during the handshake, so a TOFU caller can persist it.
pub fn client_pinned(pin: Option<[u8; 32]>) -> PinnedClient {
let observed = Arc::new(Mutex::new(None));
let ep = (|| {
let _ = rustls::crypto::ring::default_provider().install_default();
let rustls_cfg = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(PinVerify {
pin,
observed: observed.clone(),
}))
.with_no_client_auth();
let quic_cfg = quinn::crypto::rustls::QuicClientConfig::try_from(rustls_cfg)
.map_err(|e| anyhow_result::Error::msg(format!("quic client config: {e}")))?;
let mut ep = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())?;
ep.set_default_client_config(quinn::ClientConfig::new(Arc::new(quic_cfg)));
Ok(ep)
})();
(ep, observed)
}
/// Minimal error plumbing without pulling anyhow into punktfunk-core's public API.
pub mod anyhow_result {
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
pub struct Error(String);
impl Error {
pub fn msg(s: String) -> Self {
Error(s)
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
impl std::error::Error for Error {}
impl From<std::io::Error> for Error {
fn from(e: std::io::Error) -> Self {
Error(e.to_string())
}
}
}
/// Fingerprint-pinning verifier: trust is the SHA-256 of the host's (self-signed) leaf
/// cert, not a CA chain. With no pin it accepts any cert (TOFU) but still records what
/// it saw, so the embedder can persist the fingerprint and pin it from then on.
#[derive(Debug)]
struct PinVerify {
pin: Option<[u8; 32]>,
observed: Arc<Mutex<Option<[u8; 32]>>>,
}
impl rustls::client::danger::ServerCertVerifier for PinVerify {
fn verify_server_cert(
&self,
end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp: &[u8],
_now: rustls::pki_types::UnixTime,
) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error>
{
let fp = cert_fingerprint(end_entity.as_ref());
*self.observed.lock().unwrap() = Some(fp);
if let Some(expected) = self.pin {
if fp != expected {
return Err(rustls::Error::InvalidCertificate(
rustls::CertificateError::ApplicationVerificationFailure,
));
}
}
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
// The handshake signatures MUST be verified for real even though we pin the cert:
// CertificateVerify is what proves the peer *holds the pinned cert's private key* —
// skip it and an active MITM can replay the host's (public) certificate, match the
// pin, and complete the handshake with its own key.
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
rustls::crypto::verify_tls12_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn welcome_roundtrip() {
let w = Welcome {
abi_version: 1,
udp_port: 9999,
mode: Mode {
width: 2560,
height: 1440,
refresh_hz: 240,
},
fec: FecConfig {
scheme: FecScheme::Gf16,
fec_percent: 20,
max_data_per_block: 4096,
},
shard_payload: 1200,
encrypt: true,
key: [7u8; 16],
salt: [1, 2, 3, 4],
frames: 600,
};
assert_eq!(Welcome::decode(&w.encode()).unwrap(), w);
}
#[test]
fn hello_start_roundtrip() {
let h = Hello {
abi_version: 1,
mode: Mode {
width: 1280,
height: 720,
refresh_hz: 120,
},
};
assert_eq!(Hello::decode(&h.encode()).unwrap(), h);
let s = Start {
client_udp_port: 1234,
};
assert_eq!(Start::decode(&s.encode()).unwrap(), s);
}
#[test]
fn audio_datagram_roundtrip() {
let opus = [0x42u8; 97];
let d = encode_audio_datagram(7, 1_000_000_123, &opus);
assert_eq!(d[0], AUDIO_MAGIC);
let (seq, pts, payload) = decode_audio_datagram(&d).unwrap();
assert_eq!((seq, pts), (7, 1_000_000_123));
assert_eq!(payload, opus);
assert!(decode_audio_datagram(&d[..12]).is_none()); // truncated header
assert!(decode_audio_datagram(&[0u8; 13]).is_none()); // bad magic
// Empty payload is legal (DTX) — header-only datagram.
let header_only = encode_audio_datagram(0, 0, &[]);
let (_, _, empty) = decode_audio_datagram(&header_only).unwrap();
assert!(empty.is_empty());
}
#[test]
fn rumble_datagram_roundtrip() {
let d = encode_rumble_datagram(1, 0x1234, 0xFFFF);
assert_eq!(d[0], RUMBLE_MAGIC);
assert_eq!(decode_rumble_datagram(&d), Some((1, 0x1234, 0xFFFF)));
assert!(decode_rumble_datagram(&d[..6]).is_none());
}
#[test]
fn fingerprint_is_sha256_of_der() {
// Stable across calls, distinct for distinct certs.
let a = endpoint::cert_fingerprint(b"cert-a");
assert_eq!(a, endpoint::cert_fingerprint(b"cert-a"));
assert_ne!(a, endpoint::cert_fingerprint(b"cert-b"));
}
}
+198
View File
@@ -0,0 +1,198 @@
//! Session lifecycle and the two hot-path state machines.
//!
//! - **Host** ([`Session::submit_frame`]): encoded access unit → FEC + packetize →
//! optional AES-GCM seal → transport send.
//! - **Client** ([`Session::poll_frame`]): transport recv → optional open → reorder +
//! FEC recover + reassemble → whole access unit.
//!
//! Both directions also carry input: a client [`Session::send_input`]s events; the host
//! drains them with [`Session::poll_input`].
use crate::config::{Config, Role};
use crate::crypto::SessionCrypto;
use crate::error::{PunktfunkError, Result};
use crate::fec::{coder_for, ErasureCoder};
use crate::input::InputEvent;
use crate::packet::{Packetizer, Reassembler, ReassemblerLimits};
use crate::stats::{Stats, StatsCounters};
use crate::transport::Transport;
/// A reassembled, FEC-recovered access unit, ready to hand to the platform decoder.
pub struct Frame {
pub data: Vec<u8>,
pub frame_index: u32,
pub pts_ns: u64,
pub flags: u32,
}
/// One end of a stream. Constructed for a single [`Role`]; calling the other role's
/// methods returns [`PunktfunkError::InvalidArg`].
///
/// Note: the AEAD layer authenticates each datagram but does **not** provide anti-replay.
/// Video replays are largely absorbed by the reassembler's per-frame dedup, but replayed
/// input events are not yet filtered. A sliding-window replay filter keyed on the
/// authenticated sequence belongs with the pairing/handshake layer (M2); until then,
/// rely on the LAN/VPN transport assumption (plan §1).
pub struct Session {
config: Config,
coder: Box<dyn ErasureCoder>,
crypto: Option<SessionCrypto>,
transport: Box<dyn Transport>,
packetizer: Packetizer,
reassembler: Reassembler,
stats: StatsCounters,
/// Monotonic wire sequence, also the AES-GCM nonce counter.
next_seq: u64,
}
impl Session {
pub fn new(config: Config, transport: Box<dyn Transport>) -> Result<Session> {
config.validate()?;
let coder = coder_for(config.fec.scheme);
let crypto = config
.encrypt
.then(|| SessionCrypto::new(&config.key, config.salt, config.role));
let packetizer = Packetizer::new(&config);
let reassembler = Reassembler::new(ReassemblerLimits::from_config(&config));
Ok(Session {
coder,
crypto,
transport,
packetizer,
reassembler,
stats: StatsCounters::default(),
next_seq: 0,
config,
})
}
pub fn role(&self) -> Role {
self.config.role
}
pub fn stats(&self) -> Stats {
self.stats.snapshot()
}
/// Wrap a packet for the wire: when encrypting, prepend the 8-byte big-endian
/// sequence (the receiver derives the GCM nonce from it) then the ciphertext.
fn seal_for_wire(&mut self, packet: &[u8]) -> Result<Vec<u8>> {
let seq = self.next_seq;
self.next_seq = self.next_seq.wrapping_add(1);
match &self.crypto {
Some(c) => {
let ct = c.seal(seq, packet)?;
let mut wire = Vec::with_capacity(8 + ct.len());
wire.extend_from_slice(&seq.to_be_bytes());
wire.extend_from_slice(&ct);
Ok(wire)
}
None => Ok(packet.to_vec()),
}
}
/// Unwrap a wire datagram back into a plaintext packet.
fn open_from_wire(&self, wire: &[u8]) -> Result<Vec<u8>> {
match &self.crypto {
Some(c) => {
if wire.len() < 8 {
return Err(PunktfunkError::BadPacket);
}
let seq = u64::from_be_bytes(wire[..8].try_into().unwrap());
c.open(seq, &wire[8..])
}
None => Ok(wire.to_vec()),
}
}
// -- Host path --------------------------------------------------------
/// Host: FEC-protect, packetize, seal, and send one encoded access unit.
pub fn submit_frame(&mut self, data: &[u8], pts_ns: u64, user_flags: u32) -> Result<()> {
if self.config.role != Role::Host {
return Err(PunktfunkError::InvalidArg(
"submit_frame called on a client session",
));
}
let packets = self
.packetizer
.packetize(data, pts_ns, user_flags, self.coder.as_ref())?;
StatsCounters::add(&self.stats.frames_submitted, 1);
for pkt in packets {
let wire = self.seal_for_wire(&pkt)?;
StatsCounters::add(&self.stats.packets_sent, 1);
StatsCounters::add(&self.stats.bytes_sent, wire.len() as u64);
self.transport.send(&wire)?;
}
Ok(())
}
/// Host: drain one pending input event from the client, if any.
pub fn poll_input(&mut self) -> Result<Option<InputEvent>> {
if self.config.role != Role::Host {
return Err(PunktfunkError::InvalidArg(
"poll_input called on a client session",
));
}
while let Some(wire) = self.transport.recv()? {
let pkt = match self.open_from_wire(&wire) {
Ok(p) => p,
Err(_) => continue, // drop undecryptable noise
};
StatsCounters::add(&self.stats.packets_received, 1);
if let Some(ev) = InputEvent::decode(&pkt) {
return Ok(Some(ev));
}
// Not an input datagram (e.g. stray video) — ignore and keep draining.
}
Ok(None)
}
// -- Client path ------------------------------------------------------
/// Client: drain the transport until a whole access unit is recovered, or no more
/// packets are pending ([`PunktfunkError::NoFrame`]).
pub fn poll_frame(&mut self) -> Result<Frame> {
if self.config.role != Role::Client {
return Err(PunktfunkError::InvalidArg(
"poll_frame called on a host session",
));
}
loop {
let wire = match self.transport.recv()? {
Some(w) => w,
None => return Err(PunktfunkError::NoFrame),
};
let pkt = match self.open_from_wire(&wire) {
Ok(p) => p,
Err(_) => continue,
};
StatsCounters::add(&self.stats.packets_received, 1);
StatsCounters::add(&self.stats.bytes_received, pkt.len() as u64);
// The reassembler validates the packet via its parsed header (`magic`),
// ignoring anything that isn't a well-formed video packet.
if let Some(frame) = self
.reassembler
.push(&pkt, self.coder.as_ref(), &self.stats)?
{
StatsCounters::add(&self.stats.frames_completed, 1);
return Ok(frame);
}
}
}
/// Client: serialize and send one input event to the host.
pub fn send_input(&mut self, event: &InputEvent) -> Result<()> {
if self.config.role != Role::Client {
return Err(PunktfunkError::InvalidArg(
"send_input called on a host session",
));
}
let pkt = event.encode();
let wire = self.seal_for_wire(&pkt)?;
StatsCounters::add(&self.stats.packets_sent, 1);
StatsCounters::add(&self.stats.bytes_sent, wire.len() as u64);
self.transport.send(&wire)?;
Ok(())
}
}
+55
View File
@@ -0,0 +1,55 @@
//! Live counters for the frame-pacing / quality logic and the web UI.
use std::sync::atomic::{AtomicU64, Ordering};
/// Immutable snapshot, copied across the C ABI as `PunktfunkStats`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Stats {
pub frames_submitted: u64,
pub frames_completed: u64,
pub frames_dropped: u64,
pub packets_sent: u64,
pub packets_received: u64,
pub packets_dropped: u64,
pub fec_recovered_shards: u64,
pub bytes_sent: u64,
pub bytes_received: u64,
}
/// Atomic accumulators owned by a [`Session`](crate::session::Session). Snapshot to
/// [`Stats`] for readers. `Relaxed` ordering is fine: these are monotonic counters
/// read for display, never used to synchronize other memory.
#[derive(Default)]
pub struct StatsCounters {
pub frames_submitted: AtomicU64,
pub frames_completed: AtomicU64,
pub frames_dropped: AtomicU64,
pub packets_sent: AtomicU64,
pub packets_received: AtomicU64,
pub packets_dropped: AtomicU64,
pub fec_recovered_shards: AtomicU64,
pub bytes_sent: AtomicU64,
pub bytes_received: AtomicU64,
}
impl StatsCounters {
#[inline]
pub fn add(counter: &AtomicU64, n: u64) {
counter.fetch_add(n, Ordering::Relaxed);
}
pub fn snapshot(&self) -> Stats {
let l = Ordering::Relaxed;
Stats {
frames_submitted: self.frames_submitted.load(l),
frames_completed: self.frames_completed.load(l),
frames_dropped: self.frames_dropped.load(l),
packets_sent: self.packets_sent.load(l),
packets_received: self.packets_received.load(l),
packets_dropped: self.packets_dropped.load(l),
fec_recovered_shards: self.fec_recovered_shards.load(l),
bytes_sent: self.bytes_sent.load(l),
bytes_received: self.bytes_received.load(l),
}
}
}
@@ -0,0 +1,74 @@
//! In-process transport for unit tests and the C ABI harness. Two cross-wired
//! [`LoopbackTransport`]s form a host↔client link, with optional deterministic loss so
//! tests can exercise FEC recovery without a real network.
use super::Transport;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
/// One direction of the link.
struct Channel {
queue: Mutex<VecDeque<Vec<u8>>>,
/// Drop one of every `drop_period` packets (0 = lossless).
drop_period: u32,
sent: AtomicU64,
dropped: AtomicU64,
}
impl Channel {
fn new(drop_period: u32) -> Arc<Channel> {
Arc::new(Channel {
queue: Mutex::new(VecDeque::new()),
drop_period,
sent: AtomicU64::new(0),
dropped: AtomicU64::new(0),
})
}
}
/// Sends on `tx`, receives on `rx`. Created in cross-wired pairs by [`loopback_pair`].
pub struct LoopbackTransport {
tx: Arc<Channel>,
rx: Arc<Channel>,
}
impl LoopbackTransport {
/// Number of packets this transport's send side has deliberately dropped.
pub fn dropped(&self) -> u64 {
self.tx.dropped.load(Ordering::Relaxed)
}
}
/// Create a connected `(host, client)` pair. `host_drop_period` injects loss on the
/// host→client (video) path; `client_drop_period` on the reverse (input) path.
pub fn loopback_pair(
host_drop_period: u32,
client_drop_period: u32,
) -> (LoopbackTransport, LoopbackTransport) {
let h2c = Channel::new(host_drop_period);
let c2h = Channel::new(client_drop_period);
let host = LoopbackTransport {
tx: h2c.clone(),
rx: c2h.clone(),
};
let client = LoopbackTransport { tx: c2h, rx: h2c };
(host, client)
}
impl Transport for LoopbackTransport {
fn send(&self, packet: &[u8]) -> std::io::Result<()> {
let n = self.tx.sent.fetch_add(1, Ordering::Relaxed);
if self.tx.drop_period != 0 && (n % self.tx.drop_period as u64) == 0 {
// Deterministically drop in flight (the 1st of each `drop_period` group).
self.tx.dropped.fetch_add(1, Ordering::Relaxed);
return Ok(());
}
self.tx.queue.lock().unwrap().push_back(packet.to_vec());
Ok(())
}
fn recv(&self) -> std::io::Result<Option<Vec<u8>>> {
Ok(self.rx.queue.lock().unwrap().pop_front())
}
}
@@ -0,0 +1,15 @@
//! Pluggable packet I/O. The hot path calls [`Transport::send`] / [`Transport::recv`]
//! directly — no async runtime is involved.
mod loopback;
mod udp;
pub use loopback::{loopback_pair, LoopbackTransport};
pub use udp::UdpTransport;
/// A datagram transport. `recv` is non-blocking: it returns `Ok(None)` when no packet
/// is currently available, so the caller (decode/present thread) never blocks here.
pub trait Transport: Send + Sync {
fn send(&self, packet: &[u8]) -> std::io::Result<()>;
fn recv(&self) -> std::io::Result<Option<Vec<u8>>>;
}
@@ -0,0 +1,52 @@
//! Real UDP datagram transport — native sockets, no async runtime.
//!
//! M1 uses one `recv` syscall per packet; the latency budget (§7) calls for
//! `sendmmsg`/UDP-GSO batching to cut syscalls, which is a P2 optimization layered on
//! this same [`Transport`] seam.
use super::Transport;
use crate::packet::MAX_DATAGRAM_BYTES;
use std::net::UdpSocket;
/// Receive buffer size. `Config::validate` bounds `shard_payload` so a well-formed
/// datagram (header + shard + crypto overhead) always fits in [`MAX_DATAGRAM_BYTES`];
/// the `+ 1` byte lets us detect an oversized datagram (a full read) instead of
/// silently truncating it.
const RECV_BUF: usize = MAX_DATAGRAM_BYTES + 1;
pub struct UdpTransport {
socket: UdpSocket,
}
impl UdpTransport {
/// Bind `local` and `connect` to `peer`, so `send`/`recv` need no address and the
/// kernel filters to this peer. Non-blocking, matching the [`Transport`] contract.
pub fn connect(local: &str, peer: &str) -> std::io::Result<Self> {
let socket = UdpSocket::bind(local)?;
socket.connect(peer)?;
socket.set_nonblocking(true)?;
Ok(UdpTransport { socket })
}
}
impl Transport for UdpTransport {
fn send(&self, packet: &[u8]) -> std::io::Result<()> {
self.socket.send(packet)?;
Ok(())
}
fn recv(&self) -> std::io::Result<Option<Vec<u8>>> {
let mut buf = vec![0u8; RECV_BUF];
match self.socket.recv(&mut buf) {
// A read that fills the whole buffer means the datagram was larger than any
// valid packet — drop it rather than hand a truncated, corrupt packet up.
Ok(n) if n >= RECV_BUF => Ok(None),
Ok(n) => {
buf.truncate(n);
Ok(Some(buf))
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
Err(e) => Err(e),
}
}
}
+109
View File
@@ -0,0 +1,109 @@
/*
* punktfunk-core C ABI harness — M1 acceptance.
*
* Proves the core links from C and round-trips encoded access units through the full
* packetize -> FEC -> in-process loopback (with deterministic packet loss) -> FEC
* recover -> reassemble path, recovering every byte exactly.
*
* Build/run: see tests/c/run.sh (also driven by `cargo test --test c_abi`).
*/
#include "punktfunk_core.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
static PunktfunkConfig make_config(uint32_t role, uint32_t drop_period) {
PunktfunkConfig c;
memset(&c, 0, sizeof(c));
c.struct_size = (uint32_t)sizeof(PunktfunkConfig);
c.role = role; /* 0 = host, 1 = client */
c.phase = 1; /* P1, GameStream-compatible */
c.fec_scheme = 0; /* GF(2^8) */
c.fec_percent = 25;
c.max_data_per_block = 64;
c.shard_payload = 1024;
c.max_frame_bytes = 8 * 1024 * 1024;
c.encrypt = 0;
c.loopback_drop_period = drop_period;
return c;
}
int main(void) {
printf("punktfunk-core C ABI harness (abi_version=%u)\n", punktfunk_abi_version());
const uint32_t DROP_PERIOD = 8; /* drop 1 of every 8 packets */
PunktfunkConfig host_cfg = make_config(0, DROP_PERIOD);
PunktfunkConfig client_cfg = make_config(1, DROP_PERIOD);
PunktfunkSession *host = NULL;
PunktfunkSession *client = NULL;
PunktfunkStatus rc = punktfunk_test_loopback_pair(&host_cfg, &client_cfg, &host, &client);
if (rc != PUNKTFUNK_STATUS_OK || !host || !client) {
fprintf(stderr, "FAIL: loopback_pair rc=%d\n", (int)rc);
return 1;
}
const size_t FRAME_LEN = 200000; /* ~196 shards across 4 FEC blocks */
const int FRAMES = 4;
uint8_t *buf = (uint8_t *)malloc(FRAME_LEN);
if (!buf) { fprintf(stderr, "FAIL: oom\n"); return 1; }
int failures = 0;
for (int f = 0; f < FRAMES; f++) {
for (size_t i = 0; i < FRAME_LEN; i++) {
buf[i] = (uint8_t)((i * 131u) + (unsigned)f * 17u);
}
rc = punktfunk_host_submit_frame(host, buf, FRAME_LEN, (uint64_t)f * 1000000u, 0);
if (rc != PUNKTFUNK_STATUS_OK) {
fprintf(stderr, "FAIL: submit frame %d rc=%d\n", f, (int)rc);
failures++;
continue;
}
PunktfunkFrame out;
memset(&out, 0, sizeof(out));
rc = punktfunk_client_poll_frame(client, &out);
if (rc != PUNKTFUNK_STATUS_OK) {
fprintf(stderr, "FAIL: poll frame %d rc=%d (expected recovery)\n", f, (int)rc);
failures++;
continue;
}
if (out.len != FRAME_LEN || memcmp(out.data, buf, FRAME_LEN) != 0) {
fprintf(stderr, "FAIL: frame %d mismatch (len=%zu want=%zu)\n",
f, (size_t)out.len, FRAME_LEN);
failures++;
continue;
}
if (out.frame_index != (uint32_t)f) {
fprintf(stderr, "FAIL: frame %d wrong index %u\n", f, out.frame_index);
failures++;
}
}
PunktfunkStats st;
memset(&st, 0, sizeof(st));
punktfunk_get_stats(client, &st);
printf("client stats: completed=%llu recovered_shards=%llu dropped_pkts=%llu rx_pkts=%llu\n",
(unsigned long long)st.frames_completed,
(unsigned long long)st.fec_recovered_shards,
(unsigned long long)st.packets_dropped,
(unsigned long long)st.packets_received);
if (st.fec_recovered_shards == 0) {
fprintf(stderr, "FAIL: expected FEC to recover lost shards, but recovered 0\n");
failures++;
}
free(buf);
punktfunk_session_free(host);
punktfunk_session_free(client);
if (failures == 0) {
printf("PASS: %d frames round-tripped byte-exact through lossy loopback\n", FRAMES);
return 0;
}
fprintf(stderr, "FAILED with %d errors\n", failures);
return 1;
}
+34
View File
@@ -0,0 +1,34 @@
#!/usr/bin/env bash
# Build punktfunk-core's staticlib, then compile + link + run the C ABI harness against it.
# Proves the core links from C. Works on Linux and macOS (link flags come from rustc).
set -euo pipefail
here="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ws="$(cd "$here/../../../.." && pwd)" # tests/c -> crates/punktfunk-core -> crates -> ws
cd "$ws"
profile="${1:-debug}"
build_flag=""
[ "$profile" = "release" ] && build_flag="--release"
echo ">> building punktfunk-core staticlib ($profile)"
cargo build -p punktfunk-core $build_flag >/dev/null
staticlib="$ws/target/$profile/libpunktfunk_core.a"
header_dir="$ws/include"
[ -f "$staticlib" ] || { echo "missing $staticlib"; exit 1; }
[ -f "$header_dir/punktfunk_core.h" ] || { echo "missing generated header"; exit 1; }
# Ask rustc what native libs the staticlib needs to link into a C program.
native_libs="$(cargo rustc -p punktfunk-core --lib --crate-type staticlib $build_flag -- \
--print native-static-libs 2>&1 | sed -n 's/.*native-static-libs: //p' | tail -1)"
echo ">> native libs: ${native_libs:-<none>}"
out="$(mktemp -d)/punktfunk_harness"
cc="${CC:-cc}"
echo ">> compiling + linking harness"
$cc -std=c11 -Wall -Wextra -O2 -I "$header_dir" \
"$here/harness.c" "$staticlib" $native_libs -o "$out"
echo ">> running"
"$out"
+85
View File
@@ -0,0 +1,85 @@
//! Runs the C ABI harness under `cargo test`: compiles `tests/c/harness.c`, links it
//! against the freshly built `libpunktfunk_core.a`, and asserts it round-trips frames
//! through the lossy loopback. The cross-platform canonical path (querying rustc for
//! link flags) is `tests/c/run.sh`; this mirrors it so `cargo test` alone covers the
//! C boundary.
use std::path::{Path, PathBuf};
use std::process::Command;
/// Native libs the Rust staticlib needs, minus the ones `cc` already links by default
/// (`-lSystem`/`-lc`), to avoid duplicate-library linker warnings. See
/// `rustc --print native-static-libs`.
fn native_libs() -> &'static [&'static str] {
if cfg!(target_os = "macos") {
&["-liconv", "-lm"]
} else if cfg!(target_os = "linux") {
&["-lgcc_s", "-lutil", "-lrt", "-lpthread", "-lm", "-ldl"]
} else {
&[]
}
}
fn ensure_staticlib(profile_dir: &Path) -> PathBuf {
let staticlib = profile_dir.join("libpunktfunk_core.a");
if !staticlib.exists() {
// `cargo test` doesn't always emit the standalone staticlib; build it. The
// outer cargo's build lock is released during test execution, so this is safe.
let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".into());
let _ = Command::new(cargo)
.args(["build", "-p", "punktfunk-core"])
.status();
}
staticlib
}
#[test]
fn c_abi_harness_round_trips() {
let manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR")); // crates/punktfunk-core
let harness = manifest.join("tests/c/harness.c");
let include = manifest.join("../../include");
let exe = std::env::current_exe().expect("current_exe");
// .../target/<profile>/deps/c_abi-<hash> -> target/<profile>
let profile_dir = exe
.parent()
.and_then(Path::parent)
.expect("profile dir")
.to_path_buf();
let staticlib = ensure_staticlib(&profile_dir);
assert!(
staticlib.exists(),
"staticlib not found at {} (run `cargo build -p punktfunk-core`)",
staticlib.display()
);
assert!(
include.join("punktfunk_core.h").exists(),
"generated header missing; build punktfunk-core to regenerate it"
);
let cc = std::env::var("CC").unwrap_or_else(|_| "cc".into());
let out = profile_dir.join("punktfunk_c_harness");
let mut compile = Command::new(&cc);
compile
.args(["-std=c11", "-Wall", "-Wextra", "-O2", "-I"])
.arg(&include)
.arg(&harness)
.arg(&staticlib)
.args(native_libs())
.arg("-o")
.arg(&out);
match compile.status() {
Ok(s) => assert!(s.success(), "C harness failed to compile/link"),
Err(e) => {
// No C toolchain (unusual) — don't fail the whole suite; run.sh covers CI.
eprintln!("skipping C ABI test: cannot invoke `{cc}`: {e}");
return;
}
}
let run = Command::new(&out).status().expect("run C harness");
assert!(run.success(), "C harness reported a round-trip failure");
}
+185
View File
@@ -0,0 +1,185 @@
//! M1 acceptance: round-trip access units through the full host→client path
//! (packetize → FEC → loopback with simulated loss → recover → reassemble) and assert
//! byte-exact recovery, for both FEC schemes, with and without encryption. Plus
//! property tests over the FEC layer's loss patterns.
use proptest::prelude::*;
use punktfunk_core::config::{Config, FecConfig, FecScheme, ProtocolPhase, Role};
use punktfunk_core::fec::coder_for;
use punktfunk_core::input::{InputEvent, InputKind};
use punktfunk_core::session::Session;
use punktfunk_core::transport::loopback_pair;
fn config(role: Role, scheme: FecScheme, encrypt: bool, drop_period: u32) -> Config {
Config {
role,
phase: match scheme {
FecScheme::Gf8 => ProtocolPhase::P1GameStream,
FecScheme::Gf16 => ProtocolPhase::P2Punktfunk,
},
fec: FecConfig {
scheme,
fec_percent: 25,
max_data_per_block: 32,
},
shard_payload: 1024,
max_frame_bytes: 8 * 1024 * 1024,
encrypt,
key: [7u8; 16],
salt: [1, 2, 3, 4],
loopback_drop_period: drop_period,
}
}
/// Drive `frames` access units host→client over a lossy loopback and assert each one
/// comes back byte-identical. Returns the client's final stats.
fn run_stream(
scheme: FecScheme,
encrypt: bool,
drop_period: u32,
frames: &[Vec<u8>],
) -> punktfunk_core::Stats {
let (host_tp, client_tp) = loopback_pair(drop_period, 0);
let mut host = Session::new(
config(Role::Host, scheme, encrypt, drop_period),
Box::new(host_tp),
)
.unwrap();
let mut client = Session::new(
config(Role::Client, scheme, encrypt, drop_period),
Box::new(client_tp),
)
.unwrap();
for (i, frame) in frames.iter().enumerate() {
host.submit_frame(frame, i as u64 * 1_000_000, 0).unwrap();
let got = client
.poll_frame()
.expect("frame should recover despite loss");
assert_eq!(&got.data, frame, "frame {i} mismatched after recovery");
assert_eq!(got.frame_index, i as u32);
assert_eq!(got.pts_ns, i as u64 * 1_000_000);
}
client.stats()
}
fn sample_frames() -> Vec<Vec<u8>> {
(0..5usize)
.map(|f| {
let len = 1 + f * 40_000; // 1, 40k, 80k, 120k, 160k → single- and multi-block
(0..len)
.map(|b| (b.wrapping_mul(31).wrapping_add(f * 7)) as u8)
.collect()
})
.collect()
}
#[test]
fn gf8_stream_recovers_under_loss() {
let frames = sample_frames();
// drop_period 8 deletes the 1st of every 8 packets → real data-shard loss.
let stats = run_stream(FecScheme::Gf8, false, 8, &frames);
assert_eq!(stats.frames_completed, frames.len() as u64);
assert!(
stats.fec_recovered_shards > 0,
"loss should have forced FEC recovery"
);
}
#[test]
fn gf16_stream_recovers_under_loss() {
let frames = sample_frames();
let stats = run_stream(FecScheme::Gf16, false, 8, &frames);
assert_eq!(stats.frames_completed, frames.len() as u64);
assert!(stats.fec_recovered_shards > 0);
}
#[test]
fn encrypted_stream_recovers_under_loss() {
let frames = sample_frames();
let stats = run_stream(FecScheme::Gf8, true, 8, &frames);
assert_eq!(stats.frames_completed, frames.len() as u64);
}
#[test]
fn lossless_stream_is_exact() {
let frames = sample_frames();
let stats = run_stream(FecScheme::Gf16, false, 0, &frames);
assert_eq!(stats.frames_completed, frames.len() as u64);
assert_eq!(
stats.fec_recovered_shards, 0,
"no loss → nothing to recover"
);
}
#[test]
fn input_round_trips_client_to_host() {
let (host_tp, client_tp) = loopback_pair(0, 0);
let mut host = Session::new(
config(Role::Host, FecScheme::Gf8, false, 0),
Box::new(host_tp),
)
.unwrap();
let mut client = Session::new(
config(Role::Client, FecScheme::Gf8, false, 0),
Box::new(client_tp),
)
.unwrap();
let sent = InputEvent {
kind: InputKind::MouseMove,
_pad: [0; 3],
code: 0,
x: -7,
y: 13,
flags: 0,
};
client.send_input(&sent).unwrap();
let got = host
.poll_input()
.unwrap()
.expect("host should receive the input event");
assert_eq!(got, sent);
}
// ---- property tests over the FEC layer --------------------------------------
proptest! {
/// For random shard counts and an erasure set within the recovery budget, every
/// original shard is reconstructed byte-identically — for both backends.
#[test]
fn fec_recovers_any_loss_within_budget(
k in 1usize..40,
extra in 0usize..16, // recovery beyond the bare minimum
shard_half in 1usize..64, // shard_len = 2*shard_half (even)
seed in any::<u64>(),
) {
let m = (extra + 1).min(40);
let shard_len = shard_half * 2;
for coder in [coder_for(FecScheme::Gf8), coder_for(FecScheme::Gf16)] {
// Gf8 ceiling: data + recovery <= 255.
if matches!(coder.scheme(), FecScheme::Gf8) && k + m > 255 { continue; }
let data: Vec<Vec<u8>> = (0..k)
.map(|i| (0..shard_len).map(|b| (i ^ b).wrapping_add(seed as usize) as u8).collect())
.collect();
let recovery = coder.encode(&data, m).unwrap();
let mut received: Vec<Option<Vec<u8>>> =
data.iter().cloned().map(Some).chain(recovery.into_iter().map(Some)).collect();
// Erase up to `m` shards chosen by a cheap PRNG over the seed.
let total = k + m;
let lose = (seed as usize % (m + 1)).min(m);
let mut s = seed | 1;
for _ in 0..lose {
s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
let idx = (s >> 33) as usize % total;
received[idx] = None;
}
let restored = coder.reconstruct(k, m, &mut received).unwrap();
prop_assert_eq!(restored, data);
}
}
}
+17
View File
@@ -0,0 +1,17 @@
[package]
name = "fec-rs"
version = "0.1.0"
edition = "2021"
description = "A pure Rust Reed-Solomon erasure coding library with runtime SIMD acceleration"
license = "BSD-2-Clause"
repository = "https://github.com/hgaiser/fec-rs"
keywords = ["reed-solomon", "erasure", "coding", "fec", "simd"]
categories = ["algorithms", "encoding"]
readme = "README.md"
[dependencies]
rayon = { version = "1", optional = true }
[features]
default = []
parallel = ["rayon"]
+24
View File
@@ -0,0 +1,24 @@
BSD 2-Clause License
Copyright (c) 2026, Hans Gaiser
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+73
View File
@@ -0,0 +1,73 @@
# fec-rs
[![CI](https://github.com/hgaiser/fec-rs/workflows/CI/badge.svg)](https://github.com/hgaiser/fec-rs/actions)
[![Crates.io](https://img.shields.io/crates/v/fec-rs.svg)](https://crates.io/crates/fec-rs)
[![Documentation](https://docs.rs/fec-rs/badge.svg)](https://docs.rs/fec-rs)
A pure Rust Reed-Solomon erasure coding library with runtime SIMD acceleration.
## Features
- **Pure Rust** — No C/C++ dependencies or FFI. Everything is implemented in safe Rust
(with targeted `unsafe` for SIMD intrinsics).
- **Runtime SIMD detection** — Automatically uses the fastest available instruction set
via `std::is_x86_feature_detected!`. A single binary works on all x86_64 systems.
- **GF(2^8)** — Operates over the Galois field GF(2^8) with generating polynomial 29 (0x1D),
compatible with the Moonlight streaming protocol.
- **Shard-by-shard encoding** — Incremental encoding via `ShardByShard` for streaming use cases.
- **Reconstruction** — Reconstruct missing data and/or parity shards from any sufficient subset.
## SIMD Acceleration
On x86_64, the library automatically detects CPU features at runtime and uses
the best available instruction set:
- **GFNI + AVX2** — Single-instruction GF multiply on 32 bytes (Intel Alder Lake+, AMD Zen 4+)
- **AVX2** — VPSHUFB split-table nibble lookup on 32 bytes
- **GFNI + SSE** — Single-instruction GF multiply on 16 bytes
- **SSSE3** — VPSHUFB split-table nibble lookup on 16 bytes
- **Scalar** — Lookup table fallback
## Parallel Encoding
Enable the `parallel` feature for optional rayon-based parallel encoding:
```toml
fec-rs = { version = "0.1", features = ["parallel"] }
```
When enabled, large encode workloads automatically distribute parity shard
computation across threads. Small workloads use the sequential path to avoid
overhead.
## Usage
```rust
use fec_rs::ReedSolomon;
let rs = ReedSolomon::new(4, 2).unwrap();
let mut shards: Vec<Vec<u8>> = vec![
vec![0, 1, 2, 3],
vec![4, 5, 6, 7],
vec![8, 9, 10, 11],
vec![12, 13, 14, 15],
vec![0, 0, 0, 0], // parity shard 1
vec![0, 0, 0, 0], // parity shard 2
];
// Encode parity
rs.encode(&mut shards).unwrap();
// Verify
assert!(rs.verify(&shards).unwrap());
// Simulate loss of shard 0
let mut recovery: Vec<Option<Vec<u8>>> = shards.into_iter().map(Some).collect();
recovery[0] = None;
// Reconstruct
rs.reconstruct(&mut recovery).unwrap();
```
License: BSD-2-Clause
+200
View File
@@ -0,0 +1,200 @@
#![allow(clippy::needless_range_loop)]
use std::env;
use std::fs::File;
use std::io::Write;
use std::path::Path;
const FIELD_SIZE: usize = 256;
const GENERATING_POLYNOMIAL: usize = 29;
fn gen_log_table(polynomial: usize) -> [u8; FIELD_SIZE] {
let mut result = [0u8; FIELD_SIZE];
let mut b: usize = 1;
for log in 0..FIELD_SIZE - 1 {
result[b] = log as u8;
b <<= 1;
if FIELD_SIZE <= b {
b = (b - FIELD_SIZE) ^ polynomial;
}
}
result
}
const EXP_TABLE_SIZE: usize = FIELD_SIZE * 2 - 2;
fn gen_exp_table(log_table: &[u8; FIELD_SIZE]) -> [u8; EXP_TABLE_SIZE] {
let mut result = [0u8; EXP_TABLE_SIZE];
for i in 1..FIELD_SIZE {
let log = log_table[i] as usize;
result[log] = i as u8;
result[log + FIELD_SIZE - 1] = i as u8;
}
result
}
fn multiply(log_table: &[u8; FIELD_SIZE], exp_table: &[u8; EXP_TABLE_SIZE], a: u8, b: u8) -> u8 {
if a == 0 || b == 0 {
0
} else {
let log_a = log_table[a as usize];
let log_b = log_table[b as usize];
let log_result = log_a as usize + log_b as usize;
exp_table[log_result]
}
}
fn gen_mul_table(
log_table: &[u8; FIELD_SIZE],
exp_table: &[u8; EXP_TABLE_SIZE],
) -> [[u8; FIELD_SIZE]; FIELD_SIZE] {
let mut result = [[0u8; FIELD_SIZE]; FIELD_SIZE];
for a in 0..FIELD_SIZE {
for b in 0..FIELD_SIZE {
result[a][b] = multiply(log_table, exp_table, a as u8, b as u8);
}
}
result
}
fn gen_mul_table_half(
log_table: &[u8; FIELD_SIZE],
exp_table: &[u8; EXP_TABLE_SIZE],
) -> ([[u8; 16]; FIELD_SIZE], [[u8; 16]; FIELD_SIZE]) {
let mut low = [[0u8; 16]; FIELD_SIZE];
let mut high = [[0u8; 16]; FIELD_SIZE];
for a in 0..FIELD_SIZE {
for b in 0..FIELD_SIZE {
let mut result = 0;
if a != 0 && b != 0 {
let log_a = log_table[a];
let log_b = log_table[b];
result = exp_table[log_a as usize + log_b as usize];
}
if (b & 0x0F) == b {
low[a][b] = result;
}
if (b & 0xF0) == b {
high[a][b >> 4] = result;
}
}
}
(low, high)
}
/// Generate the GFNI affine matrix table.
///
/// For each constant `c` in GF(2^8), compute a u64-packed 8x8 binary matrix
/// such that `vgf2p8affineqb(x, matrix, 0)` produces `c * x` in our GF(2^8).
///
/// vgf2p8affineqb semantics:
/// result_bit[i] = popcount(x AND qword_byte[7-i]) mod 2
/// where i goes from 0 (LSB) to 7 (MSB).
///
/// Matrix packing: qword byte[7] = row for output bit 7 (MSB),
/// qword byte[0] = row for output bit 0 (LSB).
fn gen_gfni_table(
log_table: &[u8; FIELD_SIZE],
exp_table: &[u8; EXP_TABLE_SIZE],
) -> [u64; FIELD_SIZE] {
let mut result = [0u64; FIELD_SIZE];
for c in 0..FIELD_SIZE {
// Build row bytes for each output bit.
// row_for_bit_i = mask where bit j is set iff input bit j contributes to output bit i.
// M[i][j] = bit_i(c * (1 << j))
let mut rows = [0u8; 8];
for j in 0..8u8 {
let basis = 1u8 << j; // input with only bit j set
let product = multiply(log_table, exp_table, c as u8, basis);
// product's bit i tells us M[i][j]
for i in 0..8u8 {
if (product >> i) & 1 == 1 {
rows[i as usize] |= 1 << j;
}
}
}
// Pack into u64: byte[7-i] = rows[i]
// vgf2p8affineqb: result_bit[i] = popcount(x AND byte[7-i]) mod 2
// We want result_bit[i] = bit i of (c*x), so byte[7-i] = rows[i].
let mut matrix: u64 = 0;
for i in 0..8u32 {
matrix |= (rows[i as usize] as u64) << ((7 - i) * 8);
}
result[c] = matrix;
}
result
}
fn write_1d_table(f: &mut File, table: &[u8], name: &str) {
let len = table.len();
write!(f, "pub static {name}: [u8; {len}] = [").unwrap();
for v in table {
write!(f, "{v}, ").unwrap();
}
writeln!(f, "];").unwrap();
}
fn write_2d_table(f: &mut File, table: &[[u8; 16]; FIELD_SIZE], name: &str) {
let rows = table.len();
let cols = table[0].len();
write!(f, "pub static {name}: [[u8; {cols}]; {rows}] = [").unwrap();
for row in table {
write!(f, "[").unwrap();
for v in row {
write!(f, "{v}, ").unwrap();
}
writeln!(f, "],").unwrap();
}
writeln!(f, "];").unwrap();
}
fn write_mul_table(f: &mut File, table: &[[u8; FIELD_SIZE]; FIELD_SIZE]) {
let rows = table.len();
let cols = table[0].len();
write!(f, "pub static MUL_TABLE: [[u8; {cols}]; {rows}] = [").unwrap();
for row in table {
write!(f, "[").unwrap();
for v in row {
write!(f, "{v}, ").unwrap();
}
writeln!(f, "],").unwrap();
}
writeln!(f, "];").unwrap();
}
fn write_gfni_table(f: &mut File, table: &[u64; FIELD_SIZE]) {
write!(f, "pub static GFNI_TABLE: [u64; {}] = [", FIELD_SIZE).unwrap();
for v in table {
write!(f, "0x{v:016X}, ").unwrap();
}
writeln!(f, "];").unwrap();
}
fn main() {
let log_table = gen_log_table(GENERATING_POLYNOMIAL);
let exp_table = gen_exp_table(&log_table);
let mul_table = gen_mul_table(&log_table, &exp_table);
let (mul_table_low, mul_table_high) = gen_mul_table_half(&log_table, &exp_table);
let gfni_table = gen_gfni_table(&log_table, &exp_table);
let out_dir = env::var("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join("tables.rs");
let mut f = File::create(&dest_path).unwrap();
write_1d_table(&mut f, &log_table, "LOG_TABLE");
write_1d_table(&mut f, &exp_table, "EXP_TABLE");
write_mul_table(&mut f, &mul_table);
write_2d_table(&mut f, &mul_table_low, "MUL_TABLE_LOW");
write_2d_table(&mut f, &mul_table_high, "MUL_TABLE_HIGH");
write_gfni_table(&mut f, &gfni_table);
}
+61
View File
@@ -0,0 +1,61 @@
use core::fmt;
#[derive(PartialEq, Debug, Clone, Copy)]
pub enum Error {
TooFewShards,
TooManyShards,
TooFewDataShards,
TooManyDataShards,
TooFewParityShards,
TooManyParityShards,
TooFewBufferShards,
TooManyBufferShards,
IncorrectShardSize,
TooFewShardsPresent,
EmptyShard,
InvalidIndex,
InvalidParityMatrix,
SingularMatrix,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::TooFewShards => write!(f, "Too few shards"),
Error::TooManyShards => write!(f, "Too many shards"),
Error::TooFewDataShards => write!(f, "Too few data shards"),
Error::TooManyDataShards => write!(f, "Too many data shards"),
Error::TooFewParityShards => write!(f, "Too few parity shards"),
Error::TooManyParityShards => write!(f, "Too many parity shards"),
Error::TooFewBufferShards => write!(f, "Too few buffer shards"),
Error::TooManyBufferShards => write!(f, "Too many buffer shards"),
Error::IncorrectShardSize => write!(f, "Incorrect shard size"),
Error::TooFewShardsPresent => write!(f, "Too few shards present for reconstruction"),
Error::EmptyShard => write!(f, "Empty shard"),
Error::InvalidIndex => write!(f, "Invalid index"),
Error::InvalidParityMatrix => write!(f, "Invalid parity matrix"),
Error::SingularMatrix => write!(f, "Singular matrix"),
}
}
}
impl std::error::Error for Error {}
#[derive(PartialEq, Debug, Clone, Copy)]
pub enum SBSError {
TooManyCalls,
LeftoverShards,
RSError(Error),
}
impl fmt::Display for SBSError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SBSError::TooManyCalls => write!(f, "Too many calls"),
SBSError::LeftoverShards => write!(f, "Leftover shards"),
SBSError::RSError(e) => write!(f, "{e}"),
}
}
}
impl std::error::Error for SBSError {}
+636
View File
@@ -0,0 +1,636 @@
include!(concat!(env!("OUT_DIR"), "/tables.rs"));
/// Add two GF(2^8) elements (XOR).
#[inline(always)]
pub fn add(a: u8, b: u8) -> u8 {
a ^ b
}
/// Multiply two GF(2^8) elements using lookup table.
#[inline(always)]
pub fn mul(a: u8, b: u8) -> u8 {
MUL_TABLE[a as usize][b as usize]
}
/// Divide a by b in GF(2^8). Panics if b is 0.
#[inline(always)]
pub fn div(a: u8, b: u8) -> u8 {
if a == 0 {
return 0;
}
assert!(b != 0, "Division by zero in GF(2^8)");
let log_a = LOG_TABLE[a as usize] as isize;
let log_b = LOG_TABLE[b as usize] as isize;
let mut log_result = log_a - log_b;
if log_result < 0 {
log_result += 255;
}
EXP_TABLE[log_result as usize]
}
/// Compute a^n in GF(2^8).
#[inline(always)]
pub fn exp(a: u8, n: usize) -> u8 {
if n == 0 {
return 1;
}
if a == 0 {
return 0;
}
let log_a = LOG_TABLE[a as usize] as usize;
let log_result = log_a * (n % 255) % 255;
EXP_TABLE[log_result]
}
/// Multiply each element of `input` by `c` and write to `out`.
///
/// Uses SIMD acceleration when available:
/// - GFNI + AVX2 (best: single-instruction GF multiply on 32 bytes)
/// - AVX2 VPSHUFB (split-table nibble lookup on 32 bytes)
/// - GFNI + SSE (single-instruction GF multiply on 16 bytes)
/// - SSSE3 VPSHUFB (split-table nibble lookup on 16 bytes)
/// - Scalar fallback
#[inline]
pub fn mul_slice(c: u8, input: &[u8], out: &mut [u8]) {
assert_eq!(input.len(), out.len());
if input.is_empty() || c == 0 {
out.iter_mut().for_each(|o| *o = 0);
return;
}
if c == 1 {
out.copy_from_slice(input);
return;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("gfni") && is_x86_feature_detected!("avx2") {
unsafe {
mul_slice_gfni_avx2(c, input, out);
}
return;
}
if is_x86_feature_detected!("avx2") {
unsafe {
mul_slice_avx2(c, input, out);
}
return;
}
if is_x86_feature_detected!("gfni") {
unsafe {
mul_slice_gfni_sse(c, input, out);
}
return;
}
if is_x86_feature_detected!("ssse3") {
unsafe {
mul_slice_ssse3(c, input, out);
}
return;
}
}
mul_slice_scalar(c, input, out);
}
/// Multiply each element of `input` by `c` and XOR into `out`.
///
/// Uses SIMD acceleration when available (same priority as `mul_slice`).
#[inline]
pub fn mul_slice_xor(c: u8, input: &[u8], out: &mut [u8]) {
assert_eq!(input.len(), out.len());
if input.is_empty() || c == 0 {
return;
}
if c == 1 {
for (o, i) in out.iter_mut().zip(input.iter()) {
*o ^= *i;
}
return;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("gfni") && is_x86_feature_detected!("avx2") {
unsafe {
mul_slice_xor_gfni_avx2(c, input, out);
}
return;
}
if is_x86_feature_detected!("avx2") {
unsafe {
mul_slice_xor_avx2(c, input, out);
}
return;
}
if is_x86_feature_detected!("gfni") {
unsafe {
mul_slice_xor_gfni_sse(c, input, out);
}
return;
}
if is_x86_feature_detected!("ssse3") {
unsafe {
mul_slice_xor_ssse3(c, input, out);
}
return;
}
}
mul_slice_xor_scalar(c, input, out);
}
/// Function pointer types for bulk GF(2^8) operations.
pub type MulSliceFn = fn(u8, &[u8], &mut [u8]);
/// Pair of (mul_slice, mul_slice_xor) function pointers for the best available SIMD path.
///
/// Unlike `mul_slice`/`mul_slice_xor`, these skip runtime feature detection on every call.
/// The caller checks once and stores the result.
///
/// Note: These raw dispatch functions do NOT handle the c==0 or c==1 special cases.
/// The caller must handle those before calling through the function pointer.
pub fn detect_mul_slice() -> (MulSliceFn, MulSliceFn) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("gfni") && is_x86_feature_detected!("avx2") {
return (
wrap_mul_slice_gfni_avx2 as MulSliceFn,
wrap_mul_slice_xor_gfni_avx2 as MulSliceFn,
);
}
if is_x86_feature_detected!("avx2") {
return (
wrap_mul_slice_avx2 as MulSliceFn,
wrap_mul_slice_xor_avx2 as MulSliceFn,
);
}
if is_x86_feature_detected!("gfni") {
return (
wrap_mul_slice_gfni_sse as MulSliceFn,
wrap_mul_slice_xor_gfni_sse as MulSliceFn,
);
}
if is_x86_feature_detected!("ssse3") {
return (
wrap_mul_slice_ssse3 as MulSliceFn,
wrap_mul_slice_xor_ssse3 as MulSliceFn,
);
}
}
(
mul_slice_scalar as MulSliceFn,
mul_slice_xor_scalar as MulSliceFn,
)
}
// Safe wrappers for SIMD functions (used as function pointer targets)
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_gfni_avx2(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_gfni_avx2(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_xor_gfni_avx2(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_xor_gfni_avx2(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_avx2(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_avx2(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_xor_avx2(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_xor_avx2(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_gfni_sse(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_gfni_sse(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_xor_gfni_sse(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_xor_gfni_sse(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_ssse3(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_ssse3(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_xor_ssse3(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_xor_ssse3(c, input, out) }
}
// ── Scalar fallback ──────────────────────────────────────────────────────
fn mul_slice_scalar(c: u8, input: &[u8], out: &mut [u8]) {
let mt = &MUL_TABLE[c as usize];
for (o, &i) in out.iter_mut().zip(input.iter()) {
*o = mt[i as usize];
}
}
fn mul_slice_xor_scalar(c: u8, input: &[u8], out: &mut [u8]) {
let mt = &MUL_TABLE[c as usize];
for (o, &i) in out.iter_mut().zip(input.iter()) {
*o ^= mt[i as usize];
}
}
// ── x86_64 SIMD implementations ─────────────────────────────────────────
// ── GFNI + AVX2 (best path: 32 bytes per vgf2p8affineqb) ──────────────
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "gfni,avx2")]
unsafe fn mul_slice_gfni_avx2(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let matrix = GFNI_TABLE[c as usize] as i64;
let mat_vec = _mm256_set1_epi64x(matrix);
let len = input.len();
let mut i = 0;
while i + 32 <= len {
let data = _mm256_loadu_si256(input.as_ptr().add(i) as *const _);
let result = _mm256_gf2p8affine_epi64_epi8(data, mat_vec, 0);
_mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut _, result);
i += 32;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) = mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "gfni,avx2")]
unsafe fn mul_slice_xor_gfni_avx2(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let matrix = GFNI_TABLE[c as usize] as i64;
let mat_vec = _mm256_set1_epi64x(matrix);
let len = input.len();
let mut i = 0;
while i + 32 <= len {
let data = _mm256_loadu_si256(input.as_ptr().add(i) as *const _);
let existing = _mm256_loadu_si256(out.as_ptr().add(i) as *const _);
let mul_result = _mm256_gf2p8affine_epi64_epi8(data, mat_vec, 0);
let result = _mm256_xor_si256(mul_result, existing);
_mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut _, result);
i += 32;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) ^= mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
// ── GFNI + SSE (16 bytes per vgf2p8affineqb) ──────────────────────────
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "gfni")]
unsafe fn mul_slice_gfni_sse(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let matrix = GFNI_TABLE[c as usize] as i64;
let mat_vec = _mm_set1_epi64x(matrix);
let len = input.len();
let mut i = 0;
while i + 16 <= len {
let data = _mm_loadu_si128(input.as_ptr().add(i) as *const _);
let result = _mm_gf2p8affine_epi64_epi8(data, mat_vec, 0);
_mm_storeu_si128(out.as_mut_ptr().add(i) as *mut _, result);
i += 16;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) = mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "gfni")]
unsafe fn mul_slice_xor_gfni_sse(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let matrix = GFNI_TABLE[c as usize] as i64;
let mat_vec = _mm_set1_epi64x(matrix);
let len = input.len();
let mut i = 0;
while i + 16 <= len {
let data = _mm_loadu_si128(input.as_ptr().add(i) as *const _);
let existing = _mm_loadu_si128(out.as_ptr().add(i) as *const _);
let mul_result = _mm_gf2p8affine_epi64_epi8(data, mat_vec, 0);
let result = _mm_xor_si128(mul_result, existing);
_mm_storeu_si128(out.as_mut_ptr().add(i) as *mut _, result);
i += 16;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) ^= mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
// ── AVX2 VPSHUFB (32 bytes, split-table nibble lookup) ─────────────────
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_slice_avx2(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let low = &MUL_TABLE_LOW[c as usize];
let high = &MUL_TABLE_HIGH[c as usize];
// Broadcast the 16-byte low/high tables to 256-bit registers by duplicating
let low_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(low.as_ptr() as *const _));
let high_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(high.as_ptr() as *const _));
let mask = _mm256_set1_epi8(0x0F);
let len = input.len();
let mut i = 0;
// Process 32 bytes at a time
while i + 32 <= len {
let data = _mm256_loadu_si256(input.as_ptr().add(i) as *const _);
let lo_nibble = _mm256_and_si256(data, mask);
let hi_nibble = _mm256_and_si256(_mm256_srli_epi64(data, 4), mask);
let lo_result = _mm256_shuffle_epi8(low_vec, lo_nibble);
let hi_result = _mm256_shuffle_epi8(high_vec, hi_nibble);
let result = _mm256_xor_si256(lo_result, hi_result);
_mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut _, result);
i += 32;
}
// Handle remaining bytes with scalar
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) = mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_slice_xor_avx2(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let low = &MUL_TABLE_LOW[c as usize];
let high = &MUL_TABLE_HIGH[c as usize];
let low_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(low.as_ptr() as *const _));
let high_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(high.as_ptr() as *const _));
let mask = _mm256_set1_epi8(0x0F);
let len = input.len();
let mut i = 0;
while i + 32 <= len {
let data = _mm256_loadu_si256(input.as_ptr().add(i) as *const _);
let existing = _mm256_loadu_si256(out.as_ptr().add(i) as *const _);
let lo_nibble = _mm256_and_si256(data, mask);
let hi_nibble = _mm256_and_si256(_mm256_srli_epi64(data, 4), mask);
let lo_result = _mm256_shuffle_epi8(low_vec, lo_nibble);
let hi_result = _mm256_shuffle_epi8(high_vec, hi_nibble);
let result = _mm256_xor_si256(_mm256_xor_si256(lo_result, hi_result), existing);
_mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut _, result);
i += 32;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) ^= mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn mul_slice_ssse3(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let low = &MUL_TABLE_LOW[c as usize];
let high = &MUL_TABLE_HIGH[c as usize];
let low_vec = _mm_loadu_si128(low.as_ptr() as *const _);
let high_vec = _mm_loadu_si128(high.as_ptr() as *const _);
let mask = _mm_set1_epi8(0x0F);
let len = input.len();
let mut i = 0;
while i + 16 <= len {
let data = _mm_loadu_si128(input.as_ptr().add(i) as *const _);
let lo_nibble = _mm_and_si128(data, mask);
let hi_nibble = _mm_and_si128(_mm_srli_epi64(data, 4), mask);
let lo_result = _mm_shuffle_epi8(low_vec, lo_nibble);
let hi_result = _mm_shuffle_epi8(high_vec, hi_nibble);
let result = _mm_xor_si128(lo_result, hi_result);
_mm_storeu_si128(out.as_mut_ptr().add(i) as *mut _, result);
i += 16;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) = mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn mul_slice_xor_ssse3(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let low = &MUL_TABLE_LOW[c as usize];
let high = &MUL_TABLE_HIGH[c as usize];
let low_vec = _mm_loadu_si128(low.as_ptr() as *const _);
let high_vec = _mm_loadu_si128(high.as_ptr() as *const _);
let mask = _mm_set1_epi8(0x0F);
let len = input.len();
let mut i = 0;
while i + 16 <= len {
let data = _mm_loadu_si128(input.as_ptr().add(i) as *const _);
let existing = _mm_loadu_si128(out.as_ptr().add(i) as *const _);
let lo_nibble = _mm_and_si128(data, mask);
let hi_nibble = _mm_and_si128(_mm_srli_epi64(data, 4), mask);
let lo_result = _mm_shuffle_epi8(low_vec, lo_nibble);
let hi_result = _mm_shuffle_epi8(high_vec, hi_nibble);
let result = _mm_xor_si128(_mm_xor_si128(lo_result, hi_result), existing);
_mm_storeu_si128(out.as_mut_ptr().add(i) as *mut _, result);
i += 16;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) ^= mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gfni_table() {
// Verify GFNI_TABLE by emulating vgf2p8affineqb in software:
// result_bit[i] = popcount(x AND qword_byte[7-i]) mod 2
for c in 0u16..256 {
let matrix = GFNI_TABLE[c as usize];
for b in 0u16..256 {
let expected = MUL_TABLE[c as usize][b as usize];
let x = b as u8;
let mut result: u8 = 0;
for i in 0..8u32 {
let row_byte = ((matrix >> ((7 - i) * 8)) & 0xFF) as u8;
let dot = (row_byte & x).count_ones() % 2;
result |= (dot as u8) << i;
}
assert_eq!(
result, expected,
"GFNI table mismatch: c={c}, b={b}, got={result}, expected={expected}"
);
}
}
}
#[test]
fn test_add() {
assert_eq!(add(0, 0), 0);
assert_eq!(add(1, 0), 1);
assert_eq!(add(0, 1), 1);
assert_eq!(add(1, 1), 0);
assert_eq!(add(0xFF, 0xFF), 0);
assert_eq!(add(0xAA, 0x55), 0xFF);
}
#[test]
fn test_mul() {
assert_eq!(mul(0, 0), 0);
assert_eq!(mul(1, 0), 0);
assert_eq!(mul(0, 1), 0);
assert_eq!(mul(1, 1), 1);
// a * 1 = a
for a in 0u8..=255 {
assert_eq!(mul(a, 1), a);
assert_eq!(mul(1, a), a);
}
// a * 0 = 0
for a in 0u8..=255 {
assert_eq!(mul(a, 0), 0);
}
}
#[test]
fn test_div() {
// a / 1 = a
for a in 0u8..=255 {
assert_eq!(div(a, 1), a);
}
// a / a = 1 (for a != 0)
for a in 1u8..=255 {
assert_eq!(div(a, a), 1);
}
// (a * b) / b = a
for a in 1u8..=255 {
for b in 1u8..=255 {
assert_eq!(div(mul(a, b), b), a);
}
}
}
#[test]
fn test_exp() {
assert_eq!(exp(0, 0), 1);
assert_eq!(exp(1, 0), 1);
assert_eq!(exp(5, 0), 1);
assert_eq!(exp(0, 1), 0);
assert_eq!(exp(0, 100), 0);
// a^1 = a
for a in 0u8..=255 {
assert_eq!(exp(a, 1), a);
}
// a^2 = a * a
for a in 0u8..=255 {
assert_eq!(exp(a, 2), mul(a, a));
}
}
#[test]
fn test_mul_slice_basic() {
let input = [1u8, 2, 3, 4, 5, 6, 7, 8];
let mut out = [0u8; 8];
mul_slice(3, &input, &mut out);
for i in 0..input.len() {
assert_eq!(out[i], mul(3, input[i]));
}
}
#[test]
fn test_mul_slice_xor_basic() {
let input = [1u8, 2, 3, 4, 5, 6, 7, 8];
let mut out = [10u8; 8];
let original = out;
mul_slice_xor(3, &input, &mut out);
for i in 0..input.len() {
assert_eq!(out[i], original[i] ^ mul(3, input[i]));
}
}
#[test]
fn test_mul_slice_large() {
// Test with a buffer large enough to exercise SIMD paths
let input: Vec<u8> = (0..256).map(|i| i as u8).collect();
let mut out = vec![0u8; 256];
let mut expected = vec![0u8; 256];
for c in [2u8, 7, 42, 128, 255] {
mul_slice_scalar(c, &input, &mut expected);
mul_slice(c, &input, &mut out);
assert_eq!(out, expected, "mul_slice mismatch for c={c}");
}
}
#[test]
fn test_mul_slice_xor_large() {
let input: Vec<u8> = (0..256).map(|i| i as u8).collect();
for c in [2u8, 7, 42, 128, 255] {
let mut out_expected = vec![0xABu8; 256];
let mut out_simd = out_expected.clone();
mul_slice_xor_scalar(c, &input, &mut out_expected);
mul_slice_xor(c, &input, &mut out_simd);
assert_eq!(out_simd, out_expected, "mul_slice_xor mismatch for c={c}");
}
}
#[test]
fn test_mul_slice_unaligned_sizes() {
// Test sizes that don't align to SIMD width
for size in [1, 7, 15, 16, 17, 31, 32, 33, 63, 64, 65, 100] {
let input: Vec<u8> = (0..size).map(|i| i as u8).collect();
let mut out = vec![0u8; size];
let mut expected = vec![0u8; size];
mul_slice_scalar(42, &input, &mut expected);
mul_slice(42, &input, &mut out);
assert_eq!(out, expected, "mul_slice mismatch for size={size}");
}
}
}
+73
View File
@@ -0,0 +1,73 @@
//! A pure Rust Reed-Solomon erasure coding library with runtime SIMD acceleration.
//!
//! # Features
//!
//! - **Pure Rust** — No C/C++ dependencies or FFI. Everything is implemented in safe Rust
//! (with targeted `unsafe` for SIMD intrinsics).
//! - **Runtime SIMD detection** — Automatically uses the fastest available instruction set
//! via `std::is_x86_feature_detected!`. A single binary works on all x86_64 systems.
//! - **GF(2^8)** — Operates over the Galois field GF(2^8) with generating polynomial 29 (0x1D),
//! compatible with the Moonlight streaming protocol.
//! - **Shard-by-shard encoding** — Incremental encoding via `ShardByShard` for streaming use cases.
//! - **Reconstruction** — Reconstruct missing data and/or parity shards from any sufficient subset.
//!
//! # SIMD Acceleration
//!
//! On x86_64, the library automatically detects CPU features at runtime and uses
//! the best available instruction set:
//!
//! - **GFNI + AVX2** — Single-instruction GF multiply on 32 bytes (Intel Alder Lake+, AMD Zen 4+)
//! - **AVX2** — VPSHUFB split-table nibble lookup on 32 bytes
//! - **GFNI + SSE** — Single-instruction GF multiply on 16 bytes
//! - **SSSE3** — VPSHUFB split-table nibble lookup on 16 bytes
//! - **Scalar** — Lookup table fallback
//!
//! # Parallel Encoding
//!
//! Enable the `parallel` feature for optional rayon-based parallel encoding:
//!
//! ```toml
//! fec-rs = { version = "0.1", features = ["parallel"] }
//! ```
//!
//! When enabled, large encode workloads automatically distribute parity shard
//! computation across threads. Small workloads use the sequential path to avoid
//! overhead.
//!
//! # Usage
//!
//! ```
//! use fec_rs::ReedSolomon;
//!
//! let rs = ReedSolomon::new(4, 2).unwrap();
//!
//! let mut shards: Vec<Vec<u8>> = vec![
//! vec![0, 1, 2, 3],
//! vec![4, 5, 6, 7],
//! vec![8, 9, 10, 11],
//! vec![12, 13, 14, 15],
//! vec![0, 0, 0, 0], // parity shard 1
//! vec![0, 0, 0, 0], // parity shard 2
//! ];
//!
//! // Encode parity
//! rs.encode(&mut shards).unwrap();
//!
//! // Verify
//! assert!(rs.verify(&shards).unwrap());
//!
//! // Simulate loss of shard 0
//! let mut recovery: Vec<Option<Vec<u8>>> = shards.into_iter().map(Some).collect();
//! recovery[0] = None;
//!
//! // Reconstruct
//! rs.reconstruct(&mut recovery).unwrap();
//! ```
mod errors;
pub mod galois;
mod matrix;
mod reed_solomon;
pub use errors::{Error, SBSError};
pub use reed_solomon::{ReconstructShard, ReedSolomon, ShardByShard};
+251
View File
@@ -0,0 +1,251 @@
use crate::galois;
#[derive(PartialEq, Debug, Clone)]
pub struct Matrix {
pub row_count: usize,
pub col_count: usize,
pub data: Vec<u8>,
}
impl Matrix {
pub fn new(rows: usize, cols: usize) -> Self {
Self {
row_count: rows,
col_count: cols,
data: vec![0u8; rows * cols],
}
}
pub fn identity(size: usize) -> Self {
let mut m = Self::new(size, size);
for i in 0..size {
m.data[i * size + i] = 1;
}
m
}
pub fn vandermonde(rows: usize, cols: usize) -> Self {
let mut m = Self::new(rows, cols);
for r in 0..rows {
let r_a = r as u8;
for c in 0..cols {
m.data[r * cols + c] = galois::exp(r_a, c);
}
}
m
}
#[inline]
pub fn get(&self, r: usize, c: usize) -> u8 {
self.data[r * self.col_count + c]
}
#[inline]
pub fn set(&mut self, r: usize, c: usize, val: u8) {
self.data[r * self.col_count + c] = val;
}
pub fn get_row(&self, row: usize) -> &[u8] {
let start = row * self.col_count;
&self.data[start..start + self.col_count]
}
pub fn sub_matrix(&self, rmin: usize, cmin: usize, rmax: usize, cmax: usize) -> Self {
let new_rows = rmax - rmin;
let new_cols = cmax - cmin;
let mut m = Self::new(new_rows, new_cols);
for r in rmin..rmax {
for c in cmin..cmax {
m.data[(r - rmin) * new_cols + (c - cmin)] = self.get(r, c);
}
}
m
}
pub fn multiply(&self, rhs: &Matrix) -> Self {
assert_eq!(
self.col_count, rhs.row_count,
"Matrix dimensions incompatible for multiply"
);
let mut result = Self::new(self.row_count, rhs.col_count);
for r in 0..self.row_count {
for c in 0..rhs.col_count {
let mut val = 0u8;
for i in 0..self.col_count {
val = galois::add(val, galois::mul(self.get(r, i), rhs.get(i, c)));
}
result.set(r, c, val);
}
}
result
}
pub fn augment(&self, rhs: &Matrix) -> Self {
assert_eq!(
self.row_count, rhs.row_count,
"Matrix row counts must match for augment"
);
let new_cols = self.col_count + rhs.col_count;
let mut m = Self::new(self.row_count, new_cols);
for r in 0..self.row_count {
for c in 0..self.col_count {
m.set(r, c, self.get(r, c));
}
for c in 0..rhs.col_count {
m.set(r, self.col_count + c, rhs.get(r, c));
}
}
m
}
fn swap_rows(&mut self, r1: usize, r2: usize) {
if r1 == r2 {
return;
}
let s1 = r1 * self.col_count;
let s2 = r2 * self.col_count;
for i in 0..self.col_count {
self.data.swap(s1 + i, s2 + i);
}
}
fn gaussian_elim(&mut self) -> Result<(), &'static str> {
for r in 0..self.row_count {
// Pivot search
if self.get(r, r) == 0 {
for r_below in r + 1..self.row_count {
if self.get(r_below, r) != 0 {
self.swap_rows(r, r_below);
break;
}
}
}
if self.get(r, r) == 0 {
return Err("Singular matrix");
}
// Scale to 1
if self.get(r, r) != 1 {
let scale = galois::div(1, self.get(r, r));
for c in 0..self.col_count {
let val = galois::mul(scale, self.get(r, c));
self.set(r, c, val);
}
}
// Eliminate below
for r_below in r + 1..self.row_count {
if self.get(r_below, r) != 0 {
let scale = self.get(r_below, r);
for c in 0..self.col_count {
let val =
galois::add(self.get(r_below, c), galois::mul(scale, self.get(r, c)));
self.set(r_below, c, val);
}
}
}
}
// Back substitution
for d in 0..self.row_count {
for r_above in 0..d {
if self.get(r_above, d) != 0 {
let scale = self.get(r_above, d);
for c in 0..self.col_count {
let val =
galois::add(self.get(r_above, c), galois::mul(scale, self.get(d, c)));
self.set(r_above, c, val);
}
}
}
}
Ok(())
}
pub fn invert(&self) -> Result<Self, &'static str> {
assert!(
self.row_count == self.col_count,
"Cannot invert non-square matrix"
);
let mut work = self.augment(&Self::identity(self.row_count));
work.gaussian_elim()?;
Ok(work.sub_matrix(0, self.row_count, self.col_count, self.col_count * 2))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn mat(data: Vec<Vec<u8>>) -> Matrix {
let rows = data.len();
let cols = data[0].len();
let flat: Vec<u8> = data.into_iter().flatten().collect();
Matrix {
row_count: rows,
col_count: cols,
data: flat,
}
}
#[test]
fn test_identity() {
let m = Matrix::identity(3);
let expected = mat(vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1]]);
assert_eq!(m, expected);
}
#[test]
fn test_multiply() {
let m1 = mat(vec![vec![1, 2], vec![3, 4]]);
let m2 = mat(vec![vec![5, 6], vec![7, 8]]);
let result = m1.multiply(&m2);
let expected = mat(vec![vec![11, 22], vec![19, 42]]);
assert_eq!(result, expected);
}
#[test]
fn test_invert() {
let m = mat(vec![
vec![56, 23, 98],
vec![3, 100, 200],
vec![45, 201, 123],
]);
let inv = m.invert().unwrap();
let expected = mat(vec![
vec![175, 133, 33],
vec![130, 13, 245],
vec![112, 35, 126],
]);
assert_eq!(inv, expected);
}
#[test]
fn test_invert_identity() {
let m = Matrix::identity(4);
let inv = m.invert().unwrap();
assert_eq!(inv, m);
}
#[test]
fn test_multiply_identity() {
let m = mat(vec![
vec![56, 23, 98],
vec![3, 100, 200],
vec![45, 201, 123],
]);
let id = Matrix::identity(3);
assert_eq!(m.multiply(&id), m);
assert_eq!(id.multiply(&m), m);
}
#[test]
fn test_invert_times_original_is_identity() {
let m = mat(vec![
vec![56, 23, 98],
vec![3, 100, 200],
vec![45, 201, 123],
]);
let inv = m.invert().unwrap();
let product = m.multiply(&inv);
assert_eq!(product, Matrix::identity(3));
}
}
File diff suppressed because it is too large Load Diff