feat: M2 P1.5 (FEC) — nanors-exact Reed-Solomon recovery for the video stream

Moonlight now reconstructs lost video shards from our parity (verified live:
under induced packet loss the picture recovers cleanly instead of failing with
"network connection too bad"; 0% added loss in normal operation).

The decisive finding: Moonlight's nanors uses a CAUCHY generator matrix
(M[j][i] = inv[(m+i)^j], GF(2^8) poly 0x1d), while reed-solomon-erasure is
Vandermonde — so its parity was NOT Moonlight-decodable, despite the old
gf8.rs comment claiming equivalence.

lumen-core:
- Swap the GF(2^8) backend from reed-solomon-erasure to a vendored fec-rs
  (vendor/fec-rs, BSD-2), which builds the byte-identical Cauchy matrix. Pure
  Rust, no FFI — keeps the "one core" hot path. This makes both lumen's own
  protocol and the GameStream parity nanors-compatible.
- Lock it with a regression test against real nanors vectors
  (k=4,m=2 [10,20,30,40] -> parity [136,0]) + an independent matrix-derived
  cross-check + an erase/recover round-trip. Existing FEC/loopback tests stay
  green, so lumen's own protocol is unaffected.

lumen-host video.rs:
- Generate m = ceil(k*pct/100) parity shards per FEC block via Gf8Coder; stamp
  fecInfo with the recomputed wire pct (100*m/k) so the client derives the same
  count; cap per-block data to 255*100/(100+pct) so k+m <= 255.
- CRITICAL byte-exactness: RS runs over the whole `blocksize` shard (Moonlight
  decodes packetSize+16 bytes from the datagram start and PACKET_RECOVERY_FAILUREs
  on a bad reconstructed `flags` byte). So the NV header fields RS must reproduce
  (streamPacketIndex/frameIndex/flags/multiFec*) are written into data shards
  BEFORE encode, and only the transport fields (RTP header/seq/timestamp +
  fecInfo) are stamped AFTER — leaving the flags byte RS-covered. Matches
  Sunshine stream.cpp. Unit-tested incl. flags recovery.
- fec_percentage wired from stream.rs (Sunshine default 20, LUMEN_FEC_PCT
  override; 0 = data-only). LUMEN_VIDEO_DROP injects loss to test recovery.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-09 11:34:27 +00:00
parent 278a6330de
commit 72f8c05aa3
14 changed files with 2921 additions and 212 deletions
+5 -1
View File
@@ -23,7 +23,11 @@ quic = ["dep:quinn", "dep:tokio"]
[dependencies]
reed-solomon-simd = "3.1" # GF(2^16) Leopard-RS, SIMD, O(n log n) — the wall-breaker (P2)
reed-solomon-erasure = "6.0" # GF(2^8) classic RS — GameStream/Moonlight compat (P1)
# Vendored fork of fec-rs: GF(2^8) classic RS with the *Cauchy* generator matrix
# (M[j][i] = inv[(m+i)^j]) — byte-identical to the `nanors` library Moonlight uses, so our
# parity is decodable by a stock Moonlight client. (reed-solomon-erasure is Vandermonde and is
# NOT interoperable.) See vendor/fec-rs/LICENSE (BSD-2-Clause).
fec-rs = { path = "vendor/fec-rs" }
aes-gcm = "0.10" # AES-128-GCM session crypto, matches GameStream
zerocopy = { version = "0.8", features = ["derive"] }
bytes = "1"
+73 -4
View File
@@ -1,9 +1,12 @@
//! GF(2⁸) classic ReedSolomon backend (`reed-solomon-erasure`), equivalent to the
//! `nanors` library Moonlight uses. Hard ceiling: data + recovery ≤ 255 shards/block.
//! GF(2⁸) classic ReedSolomon backend (vendored `fec-rs`). Uses the **Cauchy** generator
//! matrix `M[j][i] = inv[(m+i)^j]` over GF(2⁸) (poly 0x1d) — byte-identical to the `nanors`
//! library Moonlight uses, so the parity this produces is recoverable by a stock Moonlight
//! client (unlike Vandermonde RS, whose parity is not interoperable). Hard ceiling: data +
//! recovery ≤ 255 shards/block.
use super::{validate_block_shape, validate_encode_shape, ErasureCoder, FecError};
use crate::config::FecScheme;
use reed_solomon_erasure::galois_8::ReedSolomon;
use fec_rs::ReedSolomon;
pub struct Gf8Coder;
@@ -21,7 +24,7 @@ impl ErasureCoder for Gf8Coder {
let shard_len = data[0].len();
let rs = ReedSolomon::new(k, recovery_count)
.map_err(|_| FecError::Config("invalid GF(2^8) shard counts"))?;
// reed-solomon-erasure fills parity in place: shards = data || zeroed parity.
// fec-rs fills parity in place: shards = data || zeroed parity.
let mut shards: Vec<Vec<u8>> = Vec::with_capacity(k + recovery_count);
shards.extend_from_slice(data);
shards.resize_with(k + recovery_count, || vec![0u8; shard_len]);
@@ -69,3 +72,69 @@ fn collect_originals(
}
Ok(out)
}
#[cfg(test)]
mod tests {
use super::*;
/// Locks byte-exact compatibility with Moonlight's `nanors` (Cauchy matrix
/// `M[j][i] = inv[(m+i)^j]`, GF(2⁸) poly 0x1d). If the backend ever switched matrices,
/// these vectors would break and our parity would no longer be Moonlight-decodable.
#[test]
fn nanors_exact_parity_vectors() {
let coder = Gf8Coder;
// The definitive nanors vector (k=4, m=2): single-byte shards [10,20,30,40] → [136, 0].
let data = vec![vec![10u8], vec![20], vec![30], vec![40]];
let parity = coder.encode(&data, 2).unwrap();
assert_eq!(parity, vec![vec![136u8], vec![0u8]]);
// Cross-check independently from the Cauchy parity rows (proves the matrix, not just a
// memorized output): parity[j] = XOR_i M[j][i] · data[i] over GF(2⁸).
let rows = [[142u8, 244, 71, 167], [244, 142, 167, 71]];
let din = [10u8, 20, 30, 40];
for (j, row) in rows.iter().enumerate() {
let expect = row
.iter()
.zip(din)
.fold(0u8, |acc, (&m, d)| acc ^ gf_mul(m, d));
assert_eq!(parity[j][0], expect, "parity row {j}");
}
}
/// Round-trip: erase `m` data shards and confirm reconstruction recovers the originals.
#[test]
fn recovers_erased_data_shards() {
let coder = Gf8Coder;
let data: Vec<Vec<u8>> = (0..6).map(|i| vec![i as u8; 8]).collect();
let parity = coder.encode(&data, 3).unwrap();
let mut received: Vec<Option<Vec<u8>>> = data
.iter()
.cloned()
.map(Some)
.chain(parity.into_iter().map(Some))
.collect();
// Erase 3 data shards (the FEC budget) + nothing else.
received[1] = None;
received[3] = None;
received[5] = None;
let recovered = coder.reconstruct(6, 3, &mut received).unwrap();
assert_eq!(recovered, data);
}
/// GF(2⁸) multiply, reduction poly 0x1d — independent of the backend.
fn gf_mul(mut a: u8, mut b: u8) -> u8 {
let mut p = 0u8;
for _ in 0..8 {
if b & 1 != 0 {
p ^= a;
}
let hi = a & 0x80;
a <<= 1;
if hi != 0 {
a ^= 0x1d;
}
b >>= 1;
}
p
}
}
+17
View File
@@ -0,0 +1,17 @@
[package]
name = "fec-rs"
version = "0.1.0"
edition = "2021"
description = "A pure Rust Reed-Solomon erasure coding library with runtime SIMD acceleration"
license = "BSD-2-Clause"
repository = "https://github.com/hgaiser/fec-rs"
keywords = ["reed-solomon", "erasure", "coding", "fec", "simd"]
categories = ["algorithms", "encoding"]
readme = "README.md"
[dependencies]
rayon = { version = "1", optional = true }
[features]
default = []
parallel = ["rayon"]
+24
View File
@@ -0,0 +1,24 @@
BSD 2-Clause License
Copyright (c) 2026, Hans Gaiser
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+73
View File
@@ -0,0 +1,73 @@
# fec-rs
[![CI](https://github.com/hgaiser/fec-rs/workflows/CI/badge.svg)](https://github.com/hgaiser/fec-rs/actions)
[![Crates.io](https://img.shields.io/crates/v/fec-rs.svg)](https://crates.io/crates/fec-rs)
[![Documentation](https://docs.rs/fec-rs/badge.svg)](https://docs.rs/fec-rs)
A pure Rust Reed-Solomon erasure coding library with runtime SIMD acceleration.
## Features
- **Pure Rust** — No C/C++ dependencies or FFI. Everything is implemented in safe Rust
(with targeted `unsafe` for SIMD intrinsics).
- **Runtime SIMD detection** — Automatically uses the fastest available instruction set
via `std::is_x86_feature_detected!`. A single binary works on all x86_64 systems.
- **GF(2^8)** — Operates over the Galois field GF(2^8) with generating polynomial 29 (0x1D),
compatible with the Moonlight streaming protocol.
- **Shard-by-shard encoding** — Incremental encoding via `ShardByShard` for streaming use cases.
- **Reconstruction** — Reconstruct missing data and/or parity shards from any sufficient subset.
## SIMD Acceleration
On x86_64, the library automatically detects CPU features at runtime and uses
the best available instruction set:
- **GFNI + AVX2** — Single-instruction GF multiply on 32 bytes (Intel Alder Lake+, AMD Zen 4+)
- **AVX2** — VPSHUFB split-table nibble lookup on 32 bytes
- **GFNI + SSE** — Single-instruction GF multiply on 16 bytes
- **SSSE3** — VPSHUFB split-table nibble lookup on 16 bytes
- **Scalar** — Lookup table fallback
## Parallel Encoding
Enable the `parallel` feature for optional rayon-based parallel encoding:
```toml
fec-rs = { version = "0.1", features = ["parallel"] }
```
When enabled, large encode workloads automatically distribute parity shard
computation across threads. Small workloads use the sequential path to avoid
overhead.
## Usage
```rust
use fec_rs::ReedSolomon;
let rs = ReedSolomon::new(4, 2).unwrap();
let mut shards: Vec<Vec<u8>> = vec![
vec![0, 1, 2, 3],
vec![4, 5, 6, 7],
vec![8, 9, 10, 11],
vec![12, 13, 14, 15],
vec![0, 0, 0, 0], // parity shard 1
vec![0, 0, 0, 0], // parity shard 2
];
// Encode parity
rs.encode(&mut shards).unwrap();
// Verify
assert!(rs.verify(&shards).unwrap());
// Simulate loss of shard 0
let mut recovery: Vec<Option<Vec<u8>>> = shards.into_iter().map(Some).collect();
recovery[0] = None;
// Reconstruct
rs.reconstruct(&mut recovery).unwrap();
```
License: BSD-2-Clause
+200
View File
@@ -0,0 +1,200 @@
#![allow(clippy::needless_range_loop)]
use std::env;
use std::fs::File;
use std::io::Write;
use std::path::Path;
const FIELD_SIZE: usize = 256;
const GENERATING_POLYNOMIAL: usize = 29;
fn gen_log_table(polynomial: usize) -> [u8; FIELD_SIZE] {
let mut result = [0u8; FIELD_SIZE];
let mut b: usize = 1;
for log in 0..FIELD_SIZE - 1 {
result[b] = log as u8;
b <<= 1;
if FIELD_SIZE <= b {
b = (b - FIELD_SIZE) ^ polynomial;
}
}
result
}
const EXP_TABLE_SIZE: usize = FIELD_SIZE * 2 - 2;
fn gen_exp_table(log_table: &[u8; FIELD_SIZE]) -> [u8; EXP_TABLE_SIZE] {
let mut result = [0u8; EXP_TABLE_SIZE];
for i in 1..FIELD_SIZE {
let log = log_table[i] as usize;
result[log] = i as u8;
result[log + FIELD_SIZE - 1] = i as u8;
}
result
}
fn multiply(log_table: &[u8; FIELD_SIZE], exp_table: &[u8; EXP_TABLE_SIZE], a: u8, b: u8) -> u8 {
if a == 0 || b == 0 {
0
} else {
let log_a = log_table[a as usize];
let log_b = log_table[b as usize];
let log_result = log_a as usize + log_b as usize;
exp_table[log_result]
}
}
fn gen_mul_table(
log_table: &[u8; FIELD_SIZE],
exp_table: &[u8; EXP_TABLE_SIZE],
) -> [[u8; FIELD_SIZE]; FIELD_SIZE] {
let mut result = [[0u8; FIELD_SIZE]; FIELD_SIZE];
for a in 0..FIELD_SIZE {
for b in 0..FIELD_SIZE {
result[a][b] = multiply(log_table, exp_table, a as u8, b as u8);
}
}
result
}
fn gen_mul_table_half(
log_table: &[u8; FIELD_SIZE],
exp_table: &[u8; EXP_TABLE_SIZE],
) -> ([[u8; 16]; FIELD_SIZE], [[u8; 16]; FIELD_SIZE]) {
let mut low = [[0u8; 16]; FIELD_SIZE];
let mut high = [[0u8; 16]; FIELD_SIZE];
for a in 0..FIELD_SIZE {
for b in 0..FIELD_SIZE {
let mut result = 0;
if a != 0 && b != 0 {
let log_a = log_table[a];
let log_b = log_table[b];
result = exp_table[log_a as usize + log_b as usize];
}
if (b & 0x0F) == b {
low[a][b] = result;
}
if (b & 0xF0) == b {
high[a][b >> 4] = result;
}
}
}
(low, high)
}
/// Generate the GFNI affine matrix table.
///
/// For each constant `c` in GF(2^8), compute a u64-packed 8x8 binary matrix
/// such that `vgf2p8affineqb(x, matrix, 0)` produces `c * x` in our GF(2^8).
///
/// vgf2p8affineqb semantics:
/// result_bit[i] = popcount(x AND qword_byte[7-i]) mod 2
/// where i goes from 0 (LSB) to 7 (MSB).
///
/// Matrix packing: qword byte[7] = row for output bit 7 (MSB),
/// qword byte[0] = row for output bit 0 (LSB).
fn gen_gfni_table(
log_table: &[u8; FIELD_SIZE],
exp_table: &[u8; EXP_TABLE_SIZE],
) -> [u64; FIELD_SIZE] {
let mut result = [0u64; FIELD_SIZE];
for c in 0..FIELD_SIZE {
// Build row bytes for each output bit.
// row_for_bit_i = mask where bit j is set iff input bit j contributes to output bit i.
// M[i][j] = bit_i(c * (1 << j))
let mut rows = [0u8; 8];
for j in 0..8u8 {
let basis = 1u8 << j; // input with only bit j set
let product = multiply(log_table, exp_table, c as u8, basis);
// product's bit i tells us M[i][j]
for i in 0..8u8 {
if (product >> i) & 1 == 1 {
rows[i as usize] |= 1 << j;
}
}
}
// Pack into u64: byte[7-i] = rows[i]
// vgf2p8affineqb: result_bit[i] = popcount(x AND byte[7-i]) mod 2
// We want result_bit[i] = bit i of (c*x), so byte[7-i] = rows[i].
let mut matrix: u64 = 0;
for i in 0..8u32 {
matrix |= (rows[i as usize] as u64) << ((7 - i) * 8);
}
result[c] = matrix;
}
result
}
fn write_1d_table(f: &mut File, table: &[u8], name: &str) {
let len = table.len();
write!(f, "pub static {name}: [u8; {len}] = [").unwrap();
for v in table {
write!(f, "{v}, ").unwrap();
}
writeln!(f, "];").unwrap();
}
fn write_2d_table(f: &mut File, table: &[[u8; 16]; FIELD_SIZE], name: &str) {
let rows = table.len();
let cols = table[0].len();
write!(f, "pub static {name}: [[u8; {cols}]; {rows}] = [").unwrap();
for row in table {
write!(f, "[").unwrap();
for v in row {
write!(f, "{v}, ").unwrap();
}
writeln!(f, "],").unwrap();
}
writeln!(f, "];").unwrap();
}
fn write_mul_table(f: &mut File, table: &[[u8; FIELD_SIZE]; FIELD_SIZE]) {
let rows = table.len();
let cols = table[0].len();
write!(f, "pub static MUL_TABLE: [[u8; {cols}]; {rows}] = [").unwrap();
for row in table {
write!(f, "[").unwrap();
for v in row {
write!(f, "{v}, ").unwrap();
}
writeln!(f, "],").unwrap();
}
writeln!(f, "];").unwrap();
}
fn write_gfni_table(f: &mut File, table: &[u64; FIELD_SIZE]) {
write!(f, "pub static GFNI_TABLE: [u64; {}] = [", FIELD_SIZE).unwrap();
for v in table {
write!(f, "0x{v:016X}, ").unwrap();
}
writeln!(f, "];").unwrap();
}
fn main() {
let log_table = gen_log_table(GENERATING_POLYNOMIAL);
let exp_table = gen_exp_table(&log_table);
let mul_table = gen_mul_table(&log_table, &exp_table);
let (mul_table_low, mul_table_high) = gen_mul_table_half(&log_table, &exp_table);
let gfni_table = gen_gfni_table(&log_table, &exp_table);
let out_dir = env::var("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join("tables.rs");
let mut f = File::create(&dest_path).unwrap();
write_1d_table(&mut f, &log_table, "LOG_TABLE");
write_1d_table(&mut f, &exp_table, "EXP_TABLE");
write_mul_table(&mut f, &mul_table);
write_2d_table(&mut f, &mul_table_low, "MUL_TABLE_LOW");
write_2d_table(&mut f, &mul_table_high, "MUL_TABLE_HIGH");
write_gfni_table(&mut f, &gfni_table);
}
+61
View File
@@ -0,0 +1,61 @@
use core::fmt;
#[derive(PartialEq, Debug, Clone, Copy)]
pub enum Error {
TooFewShards,
TooManyShards,
TooFewDataShards,
TooManyDataShards,
TooFewParityShards,
TooManyParityShards,
TooFewBufferShards,
TooManyBufferShards,
IncorrectShardSize,
TooFewShardsPresent,
EmptyShard,
InvalidIndex,
InvalidParityMatrix,
SingularMatrix,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::TooFewShards => write!(f, "Too few shards"),
Error::TooManyShards => write!(f, "Too many shards"),
Error::TooFewDataShards => write!(f, "Too few data shards"),
Error::TooManyDataShards => write!(f, "Too many data shards"),
Error::TooFewParityShards => write!(f, "Too few parity shards"),
Error::TooManyParityShards => write!(f, "Too many parity shards"),
Error::TooFewBufferShards => write!(f, "Too few buffer shards"),
Error::TooManyBufferShards => write!(f, "Too many buffer shards"),
Error::IncorrectShardSize => write!(f, "Incorrect shard size"),
Error::TooFewShardsPresent => write!(f, "Too few shards present for reconstruction"),
Error::EmptyShard => write!(f, "Empty shard"),
Error::InvalidIndex => write!(f, "Invalid index"),
Error::InvalidParityMatrix => write!(f, "Invalid parity matrix"),
Error::SingularMatrix => write!(f, "Singular matrix"),
}
}
}
impl std::error::Error for Error {}
#[derive(PartialEq, Debug, Clone, Copy)]
pub enum SBSError {
TooManyCalls,
LeftoverShards,
RSError(Error),
}
impl fmt::Display for SBSError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SBSError::TooManyCalls => write!(f, "Too many calls"),
SBSError::LeftoverShards => write!(f, "Leftover shards"),
SBSError::RSError(e) => write!(f, "{e}"),
}
}
}
impl std::error::Error for SBSError {}
+636
View File
@@ -0,0 +1,636 @@
include!(concat!(env!("OUT_DIR"), "/tables.rs"));
/// Add two GF(2^8) elements (XOR).
#[inline(always)]
pub fn add(a: u8, b: u8) -> u8 {
a ^ b
}
/// Multiply two GF(2^8) elements using lookup table.
#[inline(always)]
pub fn mul(a: u8, b: u8) -> u8 {
MUL_TABLE[a as usize][b as usize]
}
/// Divide a by b in GF(2^8). Panics if b is 0.
#[inline(always)]
pub fn div(a: u8, b: u8) -> u8 {
if a == 0 {
return 0;
}
assert!(b != 0, "Division by zero in GF(2^8)");
let log_a = LOG_TABLE[a as usize] as isize;
let log_b = LOG_TABLE[b as usize] as isize;
let mut log_result = log_a - log_b;
if log_result < 0 {
log_result += 255;
}
EXP_TABLE[log_result as usize]
}
/// Compute a^n in GF(2^8).
#[inline(always)]
pub fn exp(a: u8, n: usize) -> u8 {
if n == 0 {
return 1;
}
if a == 0 {
return 0;
}
let log_a = LOG_TABLE[a as usize] as usize;
let log_result = log_a * (n % 255) % 255;
EXP_TABLE[log_result]
}
/// Multiply each element of `input` by `c` and write to `out`.
///
/// Uses SIMD acceleration when available:
/// - GFNI + AVX2 (best: single-instruction GF multiply on 32 bytes)
/// - AVX2 VPSHUFB (split-table nibble lookup on 32 bytes)
/// - GFNI + SSE (single-instruction GF multiply on 16 bytes)
/// - SSSE3 VPSHUFB (split-table nibble lookup on 16 bytes)
/// - Scalar fallback
#[inline]
pub fn mul_slice(c: u8, input: &[u8], out: &mut [u8]) {
assert_eq!(input.len(), out.len());
if input.is_empty() || c == 0 {
out.iter_mut().for_each(|o| *o = 0);
return;
}
if c == 1 {
out.copy_from_slice(input);
return;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("gfni") && is_x86_feature_detected!("avx2") {
unsafe {
mul_slice_gfni_avx2(c, input, out);
}
return;
}
if is_x86_feature_detected!("avx2") {
unsafe {
mul_slice_avx2(c, input, out);
}
return;
}
if is_x86_feature_detected!("gfni") {
unsafe {
mul_slice_gfni_sse(c, input, out);
}
return;
}
if is_x86_feature_detected!("ssse3") {
unsafe {
mul_slice_ssse3(c, input, out);
}
return;
}
}
mul_slice_scalar(c, input, out);
}
/// Multiply each element of `input` by `c` and XOR into `out`.
///
/// Uses SIMD acceleration when available (same priority as `mul_slice`).
#[inline]
pub fn mul_slice_xor(c: u8, input: &[u8], out: &mut [u8]) {
assert_eq!(input.len(), out.len());
if input.is_empty() || c == 0 {
return;
}
if c == 1 {
for (o, i) in out.iter_mut().zip(input.iter()) {
*o ^= *i;
}
return;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("gfni") && is_x86_feature_detected!("avx2") {
unsafe {
mul_slice_xor_gfni_avx2(c, input, out);
}
return;
}
if is_x86_feature_detected!("avx2") {
unsafe {
mul_slice_xor_avx2(c, input, out);
}
return;
}
if is_x86_feature_detected!("gfni") {
unsafe {
mul_slice_xor_gfni_sse(c, input, out);
}
return;
}
if is_x86_feature_detected!("ssse3") {
unsafe {
mul_slice_xor_ssse3(c, input, out);
}
return;
}
}
mul_slice_xor_scalar(c, input, out);
}
/// Function pointer types for bulk GF(2^8) operations.
pub type MulSliceFn = fn(u8, &[u8], &mut [u8]);
/// Pair of (mul_slice, mul_slice_xor) function pointers for the best available SIMD path.
///
/// Unlike `mul_slice`/`mul_slice_xor`, these skip runtime feature detection on every call.
/// The caller checks once and stores the result.
///
/// Note: These raw dispatch functions do NOT handle the c==0 or c==1 special cases.
/// The caller must handle those before calling through the function pointer.
pub fn detect_mul_slice() -> (MulSliceFn, MulSliceFn) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("gfni") && is_x86_feature_detected!("avx2") {
return (
wrap_mul_slice_gfni_avx2 as MulSliceFn,
wrap_mul_slice_xor_gfni_avx2 as MulSliceFn,
);
}
if is_x86_feature_detected!("avx2") {
return (
wrap_mul_slice_avx2 as MulSliceFn,
wrap_mul_slice_xor_avx2 as MulSliceFn,
);
}
if is_x86_feature_detected!("gfni") {
return (
wrap_mul_slice_gfni_sse as MulSliceFn,
wrap_mul_slice_xor_gfni_sse as MulSliceFn,
);
}
if is_x86_feature_detected!("ssse3") {
return (
wrap_mul_slice_ssse3 as MulSliceFn,
wrap_mul_slice_xor_ssse3 as MulSliceFn,
);
}
}
(
mul_slice_scalar as MulSliceFn,
mul_slice_xor_scalar as MulSliceFn,
)
}
// Safe wrappers for SIMD functions (used as function pointer targets)
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_gfni_avx2(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_gfni_avx2(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_xor_gfni_avx2(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_xor_gfni_avx2(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_avx2(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_avx2(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_xor_avx2(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_xor_avx2(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_gfni_sse(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_gfni_sse(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_xor_gfni_sse(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_xor_gfni_sse(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_ssse3(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_ssse3(c, input, out) }
}
#[cfg(target_arch = "x86_64")]
fn wrap_mul_slice_xor_ssse3(c: u8, input: &[u8], out: &mut [u8]) {
unsafe { mul_slice_xor_ssse3(c, input, out) }
}
// ── Scalar fallback ──────────────────────────────────────────────────────
fn mul_slice_scalar(c: u8, input: &[u8], out: &mut [u8]) {
let mt = &MUL_TABLE[c as usize];
for (o, &i) in out.iter_mut().zip(input.iter()) {
*o = mt[i as usize];
}
}
fn mul_slice_xor_scalar(c: u8, input: &[u8], out: &mut [u8]) {
let mt = &MUL_TABLE[c as usize];
for (o, &i) in out.iter_mut().zip(input.iter()) {
*o ^= mt[i as usize];
}
}
// ── x86_64 SIMD implementations ─────────────────────────────────────────
// ── GFNI + AVX2 (best path: 32 bytes per vgf2p8affineqb) ──────────────
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "gfni,avx2")]
unsafe fn mul_slice_gfni_avx2(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let matrix = GFNI_TABLE[c as usize] as i64;
let mat_vec = _mm256_set1_epi64x(matrix);
let len = input.len();
let mut i = 0;
while i + 32 <= len {
let data = _mm256_loadu_si256(input.as_ptr().add(i) as *const _);
let result = _mm256_gf2p8affine_epi64_epi8(data, mat_vec, 0);
_mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut _, result);
i += 32;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) = mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "gfni,avx2")]
unsafe fn mul_slice_xor_gfni_avx2(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let matrix = GFNI_TABLE[c as usize] as i64;
let mat_vec = _mm256_set1_epi64x(matrix);
let len = input.len();
let mut i = 0;
while i + 32 <= len {
let data = _mm256_loadu_si256(input.as_ptr().add(i) as *const _);
let existing = _mm256_loadu_si256(out.as_ptr().add(i) as *const _);
let mul_result = _mm256_gf2p8affine_epi64_epi8(data, mat_vec, 0);
let result = _mm256_xor_si256(mul_result, existing);
_mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut _, result);
i += 32;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) ^= mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
// ── GFNI + SSE (16 bytes per vgf2p8affineqb) ──────────────────────────
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "gfni")]
unsafe fn mul_slice_gfni_sse(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let matrix = GFNI_TABLE[c as usize] as i64;
let mat_vec = _mm_set1_epi64x(matrix);
let len = input.len();
let mut i = 0;
while i + 16 <= len {
let data = _mm_loadu_si128(input.as_ptr().add(i) as *const _);
let result = _mm_gf2p8affine_epi64_epi8(data, mat_vec, 0);
_mm_storeu_si128(out.as_mut_ptr().add(i) as *mut _, result);
i += 16;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) = mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "gfni")]
unsafe fn mul_slice_xor_gfni_sse(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let matrix = GFNI_TABLE[c as usize] as i64;
let mat_vec = _mm_set1_epi64x(matrix);
let len = input.len();
let mut i = 0;
while i + 16 <= len {
let data = _mm_loadu_si128(input.as_ptr().add(i) as *const _);
let existing = _mm_loadu_si128(out.as_ptr().add(i) as *const _);
let mul_result = _mm_gf2p8affine_epi64_epi8(data, mat_vec, 0);
let result = _mm_xor_si128(mul_result, existing);
_mm_storeu_si128(out.as_mut_ptr().add(i) as *mut _, result);
i += 16;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) ^= mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
// ── AVX2 VPSHUFB (32 bytes, split-table nibble lookup) ─────────────────
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_slice_avx2(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let low = &MUL_TABLE_LOW[c as usize];
let high = &MUL_TABLE_HIGH[c as usize];
// Broadcast the 16-byte low/high tables to 256-bit registers by duplicating
let low_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(low.as_ptr() as *const _));
let high_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(high.as_ptr() as *const _));
let mask = _mm256_set1_epi8(0x0F);
let len = input.len();
let mut i = 0;
// Process 32 bytes at a time
while i + 32 <= len {
let data = _mm256_loadu_si256(input.as_ptr().add(i) as *const _);
let lo_nibble = _mm256_and_si256(data, mask);
let hi_nibble = _mm256_and_si256(_mm256_srli_epi64(data, 4), mask);
let lo_result = _mm256_shuffle_epi8(low_vec, lo_nibble);
let hi_result = _mm256_shuffle_epi8(high_vec, hi_nibble);
let result = _mm256_xor_si256(lo_result, hi_result);
_mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut _, result);
i += 32;
}
// Handle remaining bytes with scalar
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) = mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_slice_xor_avx2(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let low = &MUL_TABLE_LOW[c as usize];
let high = &MUL_TABLE_HIGH[c as usize];
let low_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(low.as_ptr() as *const _));
let high_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(high.as_ptr() as *const _));
let mask = _mm256_set1_epi8(0x0F);
let len = input.len();
let mut i = 0;
while i + 32 <= len {
let data = _mm256_loadu_si256(input.as_ptr().add(i) as *const _);
let existing = _mm256_loadu_si256(out.as_ptr().add(i) as *const _);
let lo_nibble = _mm256_and_si256(data, mask);
let hi_nibble = _mm256_and_si256(_mm256_srli_epi64(data, 4), mask);
let lo_result = _mm256_shuffle_epi8(low_vec, lo_nibble);
let hi_result = _mm256_shuffle_epi8(high_vec, hi_nibble);
let result = _mm256_xor_si256(_mm256_xor_si256(lo_result, hi_result), existing);
_mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut _, result);
i += 32;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) ^= mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn mul_slice_ssse3(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let low = &MUL_TABLE_LOW[c as usize];
let high = &MUL_TABLE_HIGH[c as usize];
let low_vec = _mm_loadu_si128(low.as_ptr() as *const _);
let high_vec = _mm_loadu_si128(high.as_ptr() as *const _);
let mask = _mm_set1_epi8(0x0F);
let len = input.len();
let mut i = 0;
while i + 16 <= len {
let data = _mm_loadu_si128(input.as_ptr().add(i) as *const _);
let lo_nibble = _mm_and_si128(data, mask);
let hi_nibble = _mm_and_si128(_mm_srli_epi64(data, 4), mask);
let lo_result = _mm_shuffle_epi8(low_vec, lo_nibble);
let hi_result = _mm_shuffle_epi8(high_vec, hi_nibble);
let result = _mm_xor_si128(lo_result, hi_result);
_mm_storeu_si128(out.as_mut_ptr().add(i) as *mut _, result);
i += 16;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) = mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn mul_slice_xor_ssse3(c: u8, input: &[u8], out: &mut [u8]) {
use core::arch::x86_64::*;
let low = &MUL_TABLE_LOW[c as usize];
let high = &MUL_TABLE_HIGH[c as usize];
let low_vec = _mm_loadu_si128(low.as_ptr() as *const _);
let high_vec = _mm_loadu_si128(high.as_ptr() as *const _);
let mask = _mm_set1_epi8(0x0F);
let len = input.len();
let mut i = 0;
while i + 16 <= len {
let data = _mm_loadu_si128(input.as_ptr().add(i) as *const _);
let existing = _mm_loadu_si128(out.as_ptr().add(i) as *const _);
let lo_nibble = _mm_and_si128(data, mask);
let hi_nibble = _mm_and_si128(_mm_srli_epi64(data, 4), mask);
let lo_result = _mm_shuffle_epi8(low_vec, lo_nibble);
let hi_result = _mm_shuffle_epi8(high_vec, hi_nibble);
let result = _mm_xor_si128(_mm_xor_si128(lo_result, hi_result), existing);
_mm_storeu_si128(out.as_mut_ptr().add(i) as *mut _, result);
i += 16;
}
let mt = &MUL_TABLE[c as usize];
while i < len {
*out.get_unchecked_mut(i) ^= mt[*input.get_unchecked(i) as usize];
i += 1;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gfni_table() {
// Verify GFNI_TABLE by emulating vgf2p8affineqb in software:
// result_bit[i] = popcount(x AND qword_byte[7-i]) mod 2
for c in 0u16..256 {
let matrix = GFNI_TABLE[c as usize];
for b in 0u16..256 {
let expected = MUL_TABLE[c as usize][b as usize];
let x = b as u8;
let mut result: u8 = 0;
for i in 0..8u32 {
let row_byte = ((matrix >> ((7 - i) * 8)) & 0xFF) as u8;
let dot = (row_byte & x).count_ones() % 2;
result |= (dot as u8) << i;
}
assert_eq!(
result, expected,
"GFNI table mismatch: c={c}, b={b}, got={result}, expected={expected}"
);
}
}
}
#[test]
fn test_add() {
assert_eq!(add(0, 0), 0);
assert_eq!(add(1, 0), 1);
assert_eq!(add(0, 1), 1);
assert_eq!(add(1, 1), 0);
assert_eq!(add(0xFF, 0xFF), 0);
assert_eq!(add(0xAA, 0x55), 0xFF);
}
#[test]
fn test_mul() {
assert_eq!(mul(0, 0), 0);
assert_eq!(mul(1, 0), 0);
assert_eq!(mul(0, 1), 0);
assert_eq!(mul(1, 1), 1);
// a * 1 = a
for a in 0u8..=255 {
assert_eq!(mul(a, 1), a);
assert_eq!(mul(1, a), a);
}
// a * 0 = 0
for a in 0u8..=255 {
assert_eq!(mul(a, 0), 0);
}
}
#[test]
fn test_div() {
// a / 1 = a
for a in 0u8..=255 {
assert_eq!(div(a, 1), a);
}
// a / a = 1 (for a != 0)
for a in 1u8..=255 {
assert_eq!(div(a, a), 1);
}
// (a * b) / b = a
for a in 1u8..=255 {
for b in 1u8..=255 {
assert_eq!(div(mul(a, b), b), a);
}
}
}
#[test]
fn test_exp() {
assert_eq!(exp(0, 0), 1);
assert_eq!(exp(1, 0), 1);
assert_eq!(exp(5, 0), 1);
assert_eq!(exp(0, 1), 0);
assert_eq!(exp(0, 100), 0);
// a^1 = a
for a in 0u8..=255 {
assert_eq!(exp(a, 1), a);
}
// a^2 = a * a
for a in 0u8..=255 {
assert_eq!(exp(a, 2), mul(a, a));
}
}
#[test]
fn test_mul_slice_basic() {
let input = [1u8, 2, 3, 4, 5, 6, 7, 8];
let mut out = [0u8; 8];
mul_slice(3, &input, &mut out);
for i in 0..input.len() {
assert_eq!(out[i], mul(3, input[i]));
}
}
#[test]
fn test_mul_slice_xor_basic() {
let input = [1u8, 2, 3, 4, 5, 6, 7, 8];
let mut out = [10u8; 8];
let original = out;
mul_slice_xor(3, &input, &mut out);
for i in 0..input.len() {
assert_eq!(out[i], original[i] ^ mul(3, input[i]));
}
}
#[test]
fn test_mul_slice_large() {
// Test with a buffer large enough to exercise SIMD paths
let input: Vec<u8> = (0..256).map(|i| i as u8).collect();
let mut out = vec![0u8; 256];
let mut expected = vec![0u8; 256];
for c in [2u8, 7, 42, 128, 255] {
mul_slice_scalar(c, &input, &mut expected);
mul_slice(c, &input, &mut out);
assert_eq!(out, expected, "mul_slice mismatch for c={c}");
}
}
#[test]
fn test_mul_slice_xor_large() {
let input: Vec<u8> = (0..256).map(|i| i as u8).collect();
for c in [2u8, 7, 42, 128, 255] {
let mut out_expected = vec![0xABu8; 256];
let mut out_simd = out_expected.clone();
mul_slice_xor_scalar(c, &input, &mut out_expected);
mul_slice_xor(c, &input, &mut out_simd);
assert_eq!(out_simd, out_expected, "mul_slice_xor mismatch for c={c}");
}
}
#[test]
fn test_mul_slice_unaligned_sizes() {
// Test sizes that don't align to SIMD width
for size in [1, 7, 15, 16, 17, 31, 32, 33, 63, 64, 65, 100] {
let input: Vec<u8> = (0..size).map(|i| i as u8).collect();
let mut out = vec![0u8; size];
let mut expected = vec![0u8; size];
mul_slice_scalar(42, &input, &mut expected);
mul_slice(42, &input, &mut out);
assert_eq!(out, expected, "mul_slice mismatch for size={size}");
}
}
}
+73
View File
@@ -0,0 +1,73 @@
//! A pure Rust Reed-Solomon erasure coding library with runtime SIMD acceleration.
//!
//! # Features
//!
//! - **Pure Rust** — No C/C++ dependencies or FFI. Everything is implemented in safe Rust
//! (with targeted `unsafe` for SIMD intrinsics).
//! - **Runtime SIMD detection** — Automatically uses the fastest available instruction set
//! via `std::is_x86_feature_detected!`. A single binary works on all x86_64 systems.
//! - **GF(2^8)** — Operates over the Galois field GF(2^8) with generating polynomial 29 (0x1D),
//! compatible with the Moonlight streaming protocol.
//! - **Shard-by-shard encoding** — Incremental encoding via `ShardByShard` for streaming use cases.
//! - **Reconstruction** — Reconstruct missing data and/or parity shards from any sufficient subset.
//!
//! # SIMD Acceleration
//!
//! On x86_64, the library automatically detects CPU features at runtime and uses
//! the best available instruction set:
//!
//! - **GFNI + AVX2** — Single-instruction GF multiply on 32 bytes (Intel Alder Lake+, AMD Zen 4+)
//! - **AVX2** — VPSHUFB split-table nibble lookup on 32 bytes
//! - **GFNI + SSE** — Single-instruction GF multiply on 16 bytes
//! - **SSSE3** — VPSHUFB split-table nibble lookup on 16 bytes
//! - **Scalar** — Lookup table fallback
//!
//! # Parallel Encoding
//!
//! Enable the `parallel` feature for optional rayon-based parallel encoding:
//!
//! ```toml
//! fec-rs = { version = "0.1", features = ["parallel"] }
//! ```
//!
//! When enabled, large encode workloads automatically distribute parity shard
//! computation across threads. Small workloads use the sequential path to avoid
//! overhead.
//!
//! # Usage
//!
//! ```
//! use fec_rs::ReedSolomon;
//!
//! let rs = ReedSolomon::new(4, 2).unwrap();
//!
//! let mut shards: Vec<Vec<u8>> = vec![
//! vec![0, 1, 2, 3],
//! vec![4, 5, 6, 7],
//! vec![8, 9, 10, 11],
//! vec![12, 13, 14, 15],
//! vec![0, 0, 0, 0], // parity shard 1
//! vec![0, 0, 0, 0], // parity shard 2
//! ];
//!
//! // Encode parity
//! rs.encode(&mut shards).unwrap();
//!
//! // Verify
//! assert!(rs.verify(&shards).unwrap());
//!
//! // Simulate loss of shard 0
//! let mut recovery: Vec<Option<Vec<u8>>> = shards.into_iter().map(Some).collect();
//! recovery[0] = None;
//!
//! // Reconstruct
//! rs.reconstruct(&mut recovery).unwrap();
//! ```
mod errors;
pub mod galois;
mod matrix;
mod reed_solomon;
pub use errors::{Error, SBSError};
pub use reed_solomon::{ReconstructShard, ReedSolomon, ShardByShard};
+251
View File
@@ -0,0 +1,251 @@
use crate::galois;
#[derive(PartialEq, Debug, Clone)]
pub struct Matrix {
pub row_count: usize,
pub col_count: usize,
pub data: Vec<u8>,
}
impl Matrix {
pub fn new(rows: usize, cols: usize) -> Self {
Self {
row_count: rows,
col_count: cols,
data: vec![0u8; rows * cols],
}
}
pub fn identity(size: usize) -> Self {
let mut m = Self::new(size, size);
for i in 0..size {
m.data[i * size + i] = 1;
}
m
}
pub fn vandermonde(rows: usize, cols: usize) -> Self {
let mut m = Self::new(rows, cols);
for r in 0..rows {
let r_a = r as u8;
for c in 0..cols {
m.data[r * cols + c] = galois::exp(r_a, c);
}
}
m
}
#[inline]
pub fn get(&self, r: usize, c: usize) -> u8 {
self.data[r * self.col_count + c]
}
#[inline]
pub fn set(&mut self, r: usize, c: usize, val: u8) {
self.data[r * self.col_count + c] = val;
}
pub fn get_row(&self, row: usize) -> &[u8] {
let start = row * self.col_count;
&self.data[start..start + self.col_count]
}
pub fn sub_matrix(&self, rmin: usize, cmin: usize, rmax: usize, cmax: usize) -> Self {
let new_rows = rmax - rmin;
let new_cols = cmax - cmin;
let mut m = Self::new(new_rows, new_cols);
for r in rmin..rmax {
for c in cmin..cmax {
m.data[(r - rmin) * new_cols + (c - cmin)] = self.get(r, c);
}
}
m
}
pub fn multiply(&self, rhs: &Matrix) -> Self {
assert_eq!(
self.col_count, rhs.row_count,
"Matrix dimensions incompatible for multiply"
);
let mut result = Self::new(self.row_count, rhs.col_count);
for r in 0..self.row_count {
for c in 0..rhs.col_count {
let mut val = 0u8;
for i in 0..self.col_count {
val = galois::add(val, galois::mul(self.get(r, i), rhs.get(i, c)));
}
result.set(r, c, val);
}
}
result
}
pub fn augment(&self, rhs: &Matrix) -> Self {
assert_eq!(
self.row_count, rhs.row_count,
"Matrix row counts must match for augment"
);
let new_cols = self.col_count + rhs.col_count;
let mut m = Self::new(self.row_count, new_cols);
for r in 0..self.row_count {
for c in 0..self.col_count {
m.set(r, c, self.get(r, c));
}
for c in 0..rhs.col_count {
m.set(r, self.col_count + c, rhs.get(r, c));
}
}
m
}
fn swap_rows(&mut self, r1: usize, r2: usize) {
if r1 == r2 {
return;
}
let s1 = r1 * self.col_count;
let s2 = r2 * self.col_count;
for i in 0..self.col_count {
self.data.swap(s1 + i, s2 + i);
}
}
fn gaussian_elim(&mut self) -> Result<(), &'static str> {
for r in 0..self.row_count {
// Pivot search
if self.get(r, r) == 0 {
for r_below in r + 1..self.row_count {
if self.get(r_below, r) != 0 {
self.swap_rows(r, r_below);
break;
}
}
}
if self.get(r, r) == 0 {
return Err("Singular matrix");
}
// Scale to 1
if self.get(r, r) != 1 {
let scale = galois::div(1, self.get(r, r));
for c in 0..self.col_count {
let val = galois::mul(scale, self.get(r, c));
self.set(r, c, val);
}
}
// Eliminate below
for r_below in r + 1..self.row_count {
if self.get(r_below, r) != 0 {
let scale = self.get(r_below, r);
for c in 0..self.col_count {
let val =
galois::add(self.get(r_below, c), galois::mul(scale, self.get(r, c)));
self.set(r_below, c, val);
}
}
}
}
// Back substitution
for d in 0..self.row_count {
for r_above in 0..d {
if self.get(r_above, d) != 0 {
let scale = self.get(r_above, d);
for c in 0..self.col_count {
let val =
galois::add(self.get(r_above, c), galois::mul(scale, self.get(d, c)));
self.set(r_above, c, val);
}
}
}
}
Ok(())
}
pub fn invert(&self) -> Result<Self, &'static str> {
assert!(
self.row_count == self.col_count,
"Cannot invert non-square matrix"
);
let mut work = self.augment(&Self::identity(self.row_count));
work.gaussian_elim()?;
Ok(work.sub_matrix(0, self.row_count, self.col_count, self.col_count * 2))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn mat(data: Vec<Vec<u8>>) -> Matrix {
let rows = data.len();
let cols = data[0].len();
let flat: Vec<u8> = data.into_iter().flatten().collect();
Matrix {
row_count: rows,
col_count: cols,
data: flat,
}
}
#[test]
fn test_identity() {
let m = Matrix::identity(3);
let expected = mat(vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1]]);
assert_eq!(m, expected);
}
#[test]
fn test_multiply() {
let m1 = mat(vec![vec![1, 2], vec![3, 4]]);
let m2 = mat(vec![vec![5, 6], vec![7, 8]]);
let result = m1.multiply(&m2);
let expected = mat(vec![vec![11, 22], vec![19, 42]]);
assert_eq!(result, expected);
}
#[test]
fn test_invert() {
let m = mat(vec![
vec![56, 23, 98],
vec![3, 100, 200],
vec![45, 201, 123],
]);
let inv = m.invert().unwrap();
let expected = mat(vec![
vec![175, 133, 33],
vec![130, 13, 245],
vec![112, 35, 126],
]);
assert_eq!(inv, expected);
}
#[test]
fn test_invert_identity() {
let m = Matrix::identity(4);
let inv = m.invert().unwrap();
assert_eq!(inv, m);
}
#[test]
fn test_multiply_identity() {
let m = mat(vec![
vec![56, 23, 98],
vec![3, 100, 200],
vec![45, 201, 123],
]);
let id = Matrix::identity(3);
assert_eq!(m.multiply(&id), m);
assert_eq!(id.multiply(&m), m);
}
#[test]
fn test_invert_times_original_is_identity() {
let m = mat(vec![
vec![56, 23, 98],
vec![3, 100, 200],
vec![45, 201, 123],
]);
let inv = m.invert().unwrap();
let product = m.multiply(&inv);
assert_eq!(product, Matrix::identity(3));
}
}
File diff suppressed because it is too large Load Diff