Commit 34bf4f87 authored by Sai Tarun Inaganti's avatar Sai Tarun Inaganti

No commit message

No commit message
parent dc9c56c1
...@@ -7,7 +7,8 @@ edition = "2018" ...@@ -7,7 +7,8 @@ edition = "2018"
[dependencies] [dependencies]
fftw = { version = "*" } fftw = { version = "*" }
ndarray = { version = "*" } ndarray = { version = "*", features = ["rayon"] }
num = { version = "*" } num = { version = "*" }
rand = { version = "*" } rand = { version = "*" }
rayon = { version = "*" }
strum = { version = "*" } strum = { version = "*" }
\ No newline at end of file
use std::fs::File;
use std::process::exit; use std::process::exit;
use fftw::types::*; // use fftw::types::*;
use ::fhew::{ use ::fhew::{
*, *,
BinGate::*, BinGate::*,
...@@ -9,8 +10,8 @@ use ::fhew::{ ...@@ -9,8 +10,8 @@ use ::fhew::{
use rand::Rng; use rand::Rng;
fn help(cmd: &String) { fn help(cmd: &String) {
eprintln!("\nusage: {} n\n", cmd); eprintln!("\nusage: {} <count>\n", cmd);
eprintln!(" Generate a secret key sk and evaluation key ek, and repeat the following test n times:"); eprintln!(" Generate a secret key sk and evaluation key ek, and repeat the following test <count> times:");
eprintln!(" - generate random bits b1,b2,b3,b4"); eprintln!(" - generate random bits b1,b2,b3,b4");
eprintln!(" - compute ciphertexts c1, c2, c3 and c4 encrypting b1, b2, b3 and b4 under sk"); eprintln!(" - compute ciphertexts c1, c2, c3 and c4 encrypting b1, b2, b3 and b4 under sk");
eprintln!(" - homomorphically compute the encrypted (c1 NAND c2) NAND (c3 NAND c4)"); eprintln!(" - homomorphically compute the encrypted (c1 NAND c2) NAND (c3 NAND c4)");
...@@ -28,19 +29,19 @@ fn cleartext_gate(v1: bool, v2: bool, gate: BinGate) -> bool { ...@@ -28,19 +29,19 @@ fn cleartext_gate(v1: bool, v2: bool, gate: BinGate) -> bool {
} }
fn eprint_gate(gate: BinGate) { fn eprint_gate(gate: BinGate) {
match gate { match gate {
OR => eprint!(" OR\t"), OR => eprint!("OR"),
AND => eprint!(" AND\t"), AND => eprint!("AND"),
NOR => eprint!(" NOT\t"), NOR => eprint!("NOR"),
NAND => eprint!(" NAND\t") NAND => eprint!("NAND")
} }
} }
fn main() { fn main() {
// assert_eq!(q, 512); // assert_eq!(q, 512);
let mut rng = rand::thread_rng(); // let mut rng = rand::thread_rng();
let mut ffto: FFT = Default::default(); let mut ffto: FFT = Default::default();
// fftSetup(&mut ffto); // fftSetup(&mut ffto);
let mut tTestMSB: RingFFT = Default::default(); let mut t_test_msb: RingFFT = RingFFT();
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
if args.len() != 2 { if args.len() != 2 {
...@@ -48,14 +49,19 @@ fn main() { ...@@ -48,14 +49,19 @@ fn main() {
} }
let count: i32 = args[1].parse().unwrap(); let count: i32 = args[1].parse().unwrap();
eprintln!("Setting up FHEW"); eprintln!("Setting up FHEW");
fhew::setup(&mut ffto, &mut tTestMSB); fhew::setup(&mut ffto, &mut t_test_msb);
eprint!("Generating secret key ... "); eprint!("Generating secret key ... ");
let mut lwe_sk: lwe::SecretKey = Default::default(); let mut lwe_sk: lwe::SecretKey = SecretKey();
lwe::keyGen(&mut lwe_sk, &mut rng); lwe::key_gen(&mut lwe_sk);
// dbg!(&lwe_sk);
eprintln!("Done.\n"); eprintln!("Done.\n");
eprintln!("Generating evaluation key ... this may take a while ... "); eprintln!("Generating evaluation key ... this may take a while ... ");
let mut ek: EvalKey = Default::default(); let mut ek: EvalKey = Default::default();
fhew::keyGen(&mut ek, &lwe_sk, &mut rng, &mut ffto); fhew::key_gen(&mut ek, &lwe_sk, &mut ffto);
// let mut f = File::create("key.txt").unwrap();
// fwrite_ek(&ek, &mut f);
eprintln!("Done.\n"); eprintln!("Done.\n");
eprintln!("Testing depth-2 homomorphic circuits {} times.", count); eprintln!("Testing depth-2 homomorphic circuits {} times.", count);
eprintln!("Circuit shape : (a GATE NOT(b)) GATE (c GATE d)\n"); eprintln!("Circuit shape : (a GATE NOT(b)) GATE (c GATE d)\n");
...@@ -65,48 +71,67 @@ fn main() { ...@@ -65,48 +71,67 @@ fn main() {
let (mut se1, mut se2, mut e1, mut e2, mut e12): (CipherText, CipherText, CipherText, CipherText, CipherText) let (mut se1, mut se2, mut e1, mut e2, mut e12): (CipherText, CipherText, CipherText, CipherText, CipherText)
= Default::default(); = Default::default();
for i in 1..(3*count) { for i in 1..(3*count+1) {
if i % 3 != 0 { // if i != 1 {break;}
v1 = rng.gen::<i32>() % 2; if i % 3 != 0 { // 1,2
v2 = rng.gen::<i32>() % 2; v1 = (rand::thread_rng().gen::<u32>() % 2) as i32;
lwe::encrypt(&mut e1, &lwe_sk, v1, &mut rng); v2 = (rand::thread_rng().gen::<u32>() % 2) as i32;
lwe::encrypt(&mut e2, &lwe_sk, v2, &mut rng); // v1 = 0;
if i % 3 == 1 { // v2 = 0;
eprint!(" NOT\tEnc({}) = ", v2); lwe::encrypt(&mut e1, &lwe_sk, v1);
lwe::encrypt(&mut e2, &lwe_sk, v2);
// for t in 0..n {
// println!("t = {}, lwe_sk[t] = {}, e1.a[t] = {}, e2.a[t] = {}", t, lwe_sk[t], e1.a[t], e2.a[t]);
// }
if i % 3 == 1 { // 1
eprint!("\tNOT\tEnc({}) = ", v2);
let e2_temp = e2.clone(); let e2_temp = e2.clone();
fhew::hom_not(&mut e2, &e2_temp); fhew::hom_not(&mut e2, &e2_temp);
let notv2 = lwe::decrypt(&lwe_sk, &e2); let notv2 = lwe::decrypt(&lwe_sk, &e2);
eprintln!("Enc({})", v2); eprintln!("Enc({})", notv2);
if !(notv2 == !v2) { // dbg!(v2,notv2,!v2,!notv2);
if !(notv2 != v2 && notv2 * v2 == 0) {
eprintln!("ERROR: incorrect NOT Homomorphic computation at iteration {}", i+1); eprintln!("ERROR: incorrect NOT Homomorphic computation at iteration {}", i+1);
exit(1); exit(1);
} }
v2 = !v2; v2 = if v2 == 0 {1} else {0};
} }
} else { } else { // 3
v1 = sv1; v1 = sv1;
v2 = sv2; v2 = sv2;
e1 = se1.clone(); e1 = se1.clone();
e2 = se2.clone(); e2 = se2.clone();
} }
let gate: BinGate = match rng.gen::<usize>() % 4 { // let gate: BinGate = BinGate::NAND;
let gate: BinGate = match rand::thread_rng().gen::<usize>() % 4 {
0 => BinGate::OR, 0 => BinGate::OR,
1 => BinGate::AND, 1 => BinGate::AND,
2 => BinGate::NOR, 2 => BinGate::NOR,
3 => BinGate::NAND, 3 => BinGate::NAND,
_ => BinGate::OR _ => BinGate::OR
}; };
eprint!("Enc({})", v1);
lwe::encrypt(&mut e1, &lwe_sk, v1);
lwe::encrypt(&mut e2, &lwe_sk, v2);
fhew::hom_gate(&mut e12, gate, &ek, &e1, &e2, &mut ffto, &t_test_msb);
let v12: i32 = lwe::decrypt(&lwe_sk, &e12);
eprint!("Enc({})\t", v1);
eprint_gate(gate); eprint_gate(gate);
eprint!("Enc({}) = ", v2); eprint!("\tEnc({}) = ", v2);
fhew::hom_gate(&mut e12, gate, &ek, &e1, &e2, &mut ffto, &tTestMSB, &mut rng); eprint!("Enc({})", v12);
let v12: i32 = lwe::decrypt(&lwe_sk, &e12); eprintln!("");
eprintln!("Enc({})", v12); // for j in 0..n {
// println!("i = {}, j = {}, e1.a[j] = {}, e2.a[j] = {}, e12.a[j] = {}", i, j, e1.a[j], e2.a[j], e12.a[j]);
// }
// println!("i = {}\ne1.a = {:?}\ne2.a = {:?}\ne12.a = {:?}", i, e1.a, e2.a, e12.a);
// println!("e1.b = {}, e2.b = {}, e12.b = {}", e1.b, e2.b, e12.b);
match i % 3 { match i % 3 {
0 => eprintln!(""),
1 => { 1 => {
sv1 = v12; sv1 = v12;
se1 = e12.clone(); se1 = e12.clone();
...@@ -114,10 +139,10 @@ fn main() { ...@@ -114,10 +139,10 @@ fn main() {
2 => { 2 => {
sv2 = v12; sv2 = v12;
se2 = e12.clone(); se2 = e12.clone();
} },
_ => () _ => eprintln!("")
} }
// println!("i = {}, v1 = {}, v2 = {}, v12 = {}", i, v1, v2, v12);
if cleartext_gate(v1 != 0, v2 != 0, gate) != (v12 != 0) { if cleartext_gate(v1 != 0, v2 != 0, gate) != (v12 != 0) {
eprintln!("\n ERROR: incorrect Homomorphic Gate computation at iteration {}", i+1); eprintln!("\n ERROR: incorrect Homomorphic Gate computation at iteration {}", i+1);
exit(1); exit(1);
......
This diff is collapsed.
...@@ -9,26 +9,26 @@ use ndarray::*; ...@@ -9,26 +9,26 @@ use ndarray::*;
use rand::Rng; use rand::Rng;
use std::num::Wrapping; use std::num::Wrapping;
pub const n: usize = 500; pub const n: usize = 10;
pub const N: usize = 1024; pub const N: usize = 1024;
pub const N2: usize = N/2; pub const N2: usize = N/2+1;
pub const K: usize = 3; pub const K: usize = 3; // K
pub const K2: usize = 6; pub const K2: usize = 6;
pub const Q: usize = 1 << 32; pub const Q: usize = 1 << 32; // Q
pub const q: usize = 512; pub const q: usize = 512;
pub const q2: usize = 256; pub const q2: usize = q/2;
type ZmodQ = Wrapping<i32>; pub type ZmodQ = Wrapping<i32>;
type UZmodQ = Wrapping<u32>; type UZmodQ = Wrapping<u32>;
const v: ZmodQ = Wrapping((1 << 29) + 1); const V: ZmodQ = Wrapping((1 << 29) + 1);
const v_inverse: ZmodQ = Wrapping(-536870911); // 3758096385; const V_INVERSE: ZmodQ = Wrapping(-536870911); // 3758096385; 1/V mod Q
const vgprime: [ZmodQ; 3] = [Wrapping(536870913), Wrapping(2048), Wrapping(4194304)]; // [v, v<<11, v<<22]; const VGPRIME: [ZmodQ; 3] = [Wrapping(V.0), Wrapping(V.0 << 11), Wrapping(V.0 << 22)]; // [V, V<<11, V<<22];
const g_bits: [isize; 3] = [11, 11, 10]; const G_BITS: [isize; 3] = [11, 11, 10];
const g_bits_32: [isize; 3] = [21, 21, 22]; const G_BITS_32: [isize; 3] = [21, 21, 22];
pub const KS_BASE: usize = 25; pub const KS_BASE: usize = 25;
pub const KS_EXP: usize = 7; pub const KS_EXP: usize = 7;
...@@ -46,22 +46,30 @@ pub const BS_BASE: usize = 23; ...@@ -46,22 +46,30 @@ pub const BS_BASE: usize = 23;
pub const BS_EXP: usize = 2; pub const BS_EXP: usize = 2;
pub const BS_TABLE: [usize; 2] = [1, 23]; pub const BS_TABLE: [usize; 2] = [1, 23];
#[derive(Clone)] // #[derive(Clone,Debug)]
pub struct RingModQ(pub Array1<ZmodQ>); // [ZmodQ; N]; // pub struct RingModQ(pub Array1<ZmodQ>); // [ZmodQ; N];
impl Default for RingModQ { // impl Default for RingModQ {
fn default() -> RingModQ { // fn default() -> RingModQ {
RingModQ(Array::default(N)) // RingModQ(Array::default(N))
} // }
// }
pub type RingModQ = Array<ZmodQ,Ix1>;
pub fn RingModQ() -> RingModQ {
Array::default(N)
} }
#[derive(Clone)] // #[derive(Clone,Debug)]
pub struct RingFFT(pub Array1<c64>); // [c64; N2]; // pub struct RingFFT(pub Array1<c64>); // [c64; N2];
impl Default for RingFFT { // impl Default for RingFFT {
fn default() -> RingFFT { // fn default() -> RingFFT {
RingFFT(Array::default(N2)) // RingFFT(Array::default(N2))
} // }
// }
pub type RingFFT = Array<c64,Ix1>;
pub fn RingFFT() -> RingFFT {
Array::default(N2)
} }
#[derive(Copy,Clone)] #[derive(Copy,Clone,Debug)]
pub enum BinGate { OR, AND, NOR, NAND } pub enum BinGate { OR, AND, NOR, NAND }
const GATE_CONST: [usize; 4] = [15*q/8, 9*q/8, 11*q/8, 13*q/8]; const GATE_CONST: [usize; 4] = [15*q/8, 9*q/8, 11*q/8, 13*q/8];
...@@ -75,11 +83,11 @@ struct Distrib { ...@@ -75,11 +83,11 @@ struct Distrib {
table: &'static [f64] table: &'static [f64]
} }
fn sample(chi: &Distrib, rng: &mut rand::rngs::ThreadRng) -> i32 { fn sample(chi: &Distrib) -> i32 {
// dbg!(chi.std_dev); // dbg!(chi.std_dev);
if chi.max != 0 { if chi.max != 0 { // CHI1, CHI_BINARY
// println!("path 1"); // println!("path 1");
let r: f64 = rng.gen(); let r: f64 = rand::thread_rng().gen();
for i in 0..chi.max { for i in 0..chi.max {
if r <= chi.table[i as usize] { if r <= chi.table[i as usize] {
return i - chi.offset; return i - chi.offset;
...@@ -89,24 +97,24 @@ fn sample(chi: &Distrib, rng: &mut rand::rngs::ThreadRng) -> i32 { ...@@ -89,24 +97,24 @@ fn sample(chi: &Distrib, rng: &mut rand::rngs::ThreadRng) -> i32 {
} }
let mut r: f64; let mut r: f64;
let s = chi.std_dev; let s = chi.std_dev;
if s < 500.0 { if s < 500.0 { // CHI3
// println!("path 2"); // println!("path 2");
let mut x: i32; let mut x: i32;
let maxx= (s*8.0).ceil() as i32; let maxx= (s*8.0).ceil() as i32;
loop { loop {
x = rng.gen::<i32>() % (2*maxx + 1) - maxx; x = rand::thread_rng().gen::<i32>() % (2*maxx + 1) - maxx;
r = rng.gen(); r = rand::thread_rng().gen();
// println!("x = {}, y = {}, z = {}", r, x, s); // println!("x = {}, y = {}, z = {}", r, x, s);
if r < (- (x*x) as f64 / (2.0*s*s)).exp() { if r < (- (x*x) as f64 / (2.0*s*s)).exp() {
return x; return x;
} }
} }
} else { } else { // CHI2
// println!("path 3"); // println!("path 3");
let mut x: f64; let mut x: f64;
loop { loop {
x = 16.0 * rng.gen::<f64>() - 8.0; x = 16.0 * rand::thread_rng().gen::<f64>() - 8.0;
r = rng.gen(); r = rand::thread_rng().gen();
// println!("r = {}\tx = {}\ts = {}", r, x, s); // println!("r = {}\tx = {}\ts = {}", r, x, s);
if r < (- x*x / 2.0).exp() { if r < (- x*x / 2.0).exp() {
return (0.5 + x*s).floor() as i32; return (0.5 + x*s).floor() as i32;
...@@ -174,30 +182,41 @@ impl Default for FFT { ...@@ -174,30 +182,41 @@ impl Default for FFT {
in_: AlignedVec::new(N*2), in_: AlignedVec::new(N*2),
out: AlignedVec::new(N+1), out: AlignedVec::new(N+1),
plan_fft_forw: R2CPlan64::aligned(&[N*2], Flag::PATIENT).unwrap(), plan_fft_forw: R2CPlan64::aligned(&[N*2], Flag::PATIENT).unwrap(),
plan_fft_back: C2RPlan64::aligned(&[N*2], Flag::PATIENT).unwrap() plan_fft_back: C2RPlan64::aligned(&[N*2], Flag::PATIENT | Flag::PRESERVEINPUT).unwrap()
} }
} }
} }
pub fn fft_setup(ffto: &mut FFT) { pub fn fft_setup(ffto: &mut FFT) {
*ffto = Default::default(); *ffto = Default::default();
} }
pub fn fft_forward(ffto: &mut FFT, val: &RingModQ, res: &mut RingFFT) { pub fn fft_forward<'a,'b>(ffto: &mut FFT, val: ArrayView<'a,ZmodQ,Ix1>, mut res: ArrayViewMut<'b,c64,Ix1>) {
for k in 0..N { for k in 0..N {
ffto.in_[k] = val.0[k].0 as f64; ffto.in_[k] = val[k].0 as f64;
ffto.in_[k+N] = 0.0; ffto.in_[k+N] = 0.0;
} }
ffto.plan_fft_forw.r2c(&mut ffto.in_, &mut ffto.out).unwrap(); ffto.plan_fft_forw.r2c(&mut ffto.in_, &mut ffto.out).unwrap();
for k in 0..N2 { for k in 0..(N2-1) {
res.0[k] = ffto.out[2*k+1]; res[k] = ffto.out[2*k+1];
// res[k] = ffto.out[k];
} }
} }
pub fn fft_backward(ffto: &mut FFT, val: &RingFFT, res: &mut RingModQ) { pub fn fft_backward<'a,'b>(ffto: &mut FFT, val: ArrayView<'a,c64,Ix1>, mut res: ArrayViewMut<'b,ZmodQ,Ix1>) {
for k in 0..N2 { for k in 0..N2 {
ffto.out[2*k+1] = val.0[k]/c64::new(N as f64,0.0);
ffto.out[2*k] = c64::new(0.0,0.0); ffto.out[2*k] = c64::new(0.0,0.0);
if k < N2 - 1 {
ffto.out[2*k+1] = val[k]/c64::new(N as f64,0.0);
}
// ffto.out[k] = val[k]; // /c64::new(N as f64,0.0);
} }
ffto.plan_fft_back.c2r(&mut ffto.out, &mut ffto.in_).unwrap(); ffto.plan_fft_back.c2r(&mut ffto.out, &mut ffto.in_).unwrap();
for k in 0..N { for k in 0..N {
res.0[k] = Wrapping(ffto.in_[k].round() as i32); // let max: i64 = i32::MAX as i64;
// let min: i64 = i32::MIN as i64;
// let div: i64 = max - min + 1;
// let mut t = ffto.in_[k].round() as i64 % div;
// t = if t > max { t - div } else if t < min { t + div } else { t };
// res[k] = Wrapping(t as i32);
res[k] = Wrapping(ffto.in_[k].round().rem_euclid(2f64.powi(32)) as u32 as i32);
} }
} }
use rand::Rng; use rand::Rng;
use crate::*; use crate::*;
use ndarray::parallel::prelude::*;
use ndarray::linalg::*;
use rayon::prelude::*;
#[derive(Clone)] #[derive(Clone,Debug)]
pub struct CipherText { pub struct CipherText {
pub a: Array1<i32>, // [isize; n], pub a: Array1<i32>, // [isize; n],
pub b: i32 pub b: i32
...@@ -14,7 +17,7 @@ impl Default for CipherText { ...@@ -14,7 +17,7 @@ impl Default for CipherText {
} }
} }
} }
#[derive(Clone)] #[derive(Clone,Debug)]
pub struct CipherTextQ { pub struct CipherTextQ {
pub a: Array1<ZmodQ>, // [ZmodQ; n], pub a: Array1<ZmodQ>, // [ZmodQ; n],
pub b: ZmodQ pub b: ZmodQ
...@@ -27,7 +30,7 @@ impl Default for CipherTextQ { ...@@ -27,7 +30,7 @@ impl Default for CipherTextQ {
} }
} }
} }
#[derive(Clone)] #[derive(Clone,Debug)]
pub struct CipherTextQN { pub struct CipherTextQN {
pub a: Array1<ZmodQ>, // [ZmodQ; n], pub a: Array1<ZmodQ>, // [ZmodQ; n],
pub b: ZmodQ pub b: ZmodQ
...@@ -35,123 +38,145 @@ pub struct CipherTextQN { ...@@ -35,123 +38,145 @@ pub struct CipherTextQN {
impl Default for CipherTextQN { impl Default for CipherTextQN {
fn default() -> Self { fn default() -> Self {
CipherTextQN { CipherTextQN {
a: Array::default(n), a: Array::default(N),
b: Default::default() b: Default::default()
} }
} }
} }
#[derive(Clone)] // #[derive(Clone,Debug)]
pub struct SecretKey(pub Array1<i32>); // [isize; n]; // pub struct SecretKey(pub Array1<i32>); // [isize; n];
impl Default for SecretKey { // impl Default for SecretKey {
fn default() -> Self { // fn default() -> Self {
SecretKey(Array::default(n)) // SecretKey(Array::default(n))
} // }
// }
pub type SecretKey = Array<i32,Ix1>;
pub fn SecretKey() -> SecretKey {
Array::default(n)
}
// #[derive(Clone,Debug)]
// pub struct SecretKeyN(pub Array1<i32>); // [isize; N];
// impl Default for SecretKeyN {
// fn default() -> Self {
// SecretKeyN(Array::default(N))
// }
// }
pub type SecretKeyN = Array<i32,Ix1>;
pub fn SecretKeyN() -> SecretKeyN {
Array::default(N)
} }
#[derive(Clone)]
pub struct SecretKeyN(pub Array1<i32>); // [isize; N]; // #[derive(Clone,Debug)]
impl Default for SecretKeyN { // pub struct SwitchingKey(pub Array3<CipherTextQ>); // [[[CipherTextQ; KS_EXP]; KS_BASE]; N];
fn default() -> Self { // impl Default for SwitchingKey {
SecretKeyN(Array::default(N)) // fn default() -> Self {
} // SwitchingKey(Array::default((N,KS_BASE,KS_EXP)))
// }
// }
pub type SwitchingKey = Array<CipherTextQ,Ix3>;
pub fn SwitchingKey() -> SwitchingKey {
let k = Array::default((N,KS_BASE,KS_EXP));
eprintln!("evaluation key 2 -> switching key zeroing complete");
k
} }
const qi: i32 = q as i32; const qi: i32 = q as i32;
pub fn keyGen(sk: &mut SecretKey, rng: &mut rand::rngs::ThreadRng) { pub fn key_gen(sk: &mut SecretKey) {
loop { // loop {
let mut s = 0; let mut s = 0;
let mut ss = 0; let mut ss = 0;
for i in 0..n { for i in 0..n {
sk.0[i] = sample(&CHI_BINARY, rng); // sk[i] = sample(&CHI_BINARY);
s += sk.0[i]; sk[i] = BINARY_TABLE[i%3].floor() as i32;
ss += (sk.0[i]).abs(); s += sk[i];
ss += (sk[i]).abs();
} }
if s.abs() > 5 || (ss - (n as i32) / 2).abs() > 5 { // if s.abs() > 5 || (ss - (n as i32) / 2).abs() > 5 {
continue; // continue;
} else { // } else {
break; // break;
} // }
} // }
} }
pub fn keyGenN(sk: &mut SecretKeyN, rng: &mut rand::rngs::ThreadRng) { pub fn key_gen_N(sk: &mut SecretKeyN) {
for i in 0..N { for i in 0..N {
sk.0[i] = sample(&CHI1, rng); // sk[i] = sample(&CHI1);
sk[i] = (i as i32) % 2;
} }
} }
pub fn encrypt(ct: &mut CipherText, sk: &SecretKey, m: i32, rng: &mut rand::rngs::ThreadRng) { pub fn encrypt(ct: &mut CipherText, sk: &SecretKey, m: i32) {
ct.b = (m % 4) * qi / 4 + sample(&CHI3, rng); ct.b = (m % 4) * qi / 4; // + sample(&CHI3);
for i in 0..n { for i in 0..n {
ct.a[i] = rng.gen::<i32>() % qi; // ct.a[i] = rand::thread_rng().gen::<i32>() % qi;
ct.b = (ct.b + ct.a[i] * sk.0[i]) % qi; ct.a[i] = (i as i32) % qi;
ct.b = (ct.b + ct.a[i] * sk[i]) % qi;
// println!("i = {}, ct.b = {}", i, ct.b);
} }
// println!("m = {}, ct.b = {}, ct->a = {:?}", m, ct.b, ct.a);
} }
pub fn decrypt(sk: &SecretKey, ct: &CipherText) -> i32 { pub fn decrypt(sk: &SecretKey, ct: &CipherText) -> i32 {
let mut r = ct.b; let mut r = ct.b;
// dbg!(r);
for i in 0..n { for i in 0..n {
r -= ct.a[i] * sk.0[i]; r -= ct.a[i] * sk[i];
// dbg!(r);
} }
r = ((r % qi) + qi + qi/8) % qi; r = ((r % qi) + qi + qi/8) % qi;
// dbg!(r,qi);
// dbg!(r, qi, 4*r/qi);
4 * r / qi 4 * r / qi
} }
#[derive(Clone)] pub fn switching_key_gen<'b,'c>(res: &mut SwitchingKey, new_sk: ArrayView<'b,i32,Ix1>, old_sk: ArrayView<'c,i32,Ix1>) {
pub struct SwitchingKey(pub Array3<CipherTextQ>); // [[[CipherTextQ; KS_EXP]; KS_BASE]; N]; // dbg!(&res[[0,0,0]]);
impl Default for SwitchingKey {
fn default() -> Self {
SwitchingKey(Array::default((N,KS_BASE,KS_EXP)))
}
}
pub fn switchingKeyGen(res: &mut SwitchingKey, new_sk: &SecretKey, old_sk: &SecretKeyN, rng: &mut rand::rngs::ThreadRng) {
for i in 0..N { for i in 0..N {
for j in 0..KS_BASE { for j in 0..KS_BASE {
for k in 0..KS_EXP { for k in 0..KS_EXP {
// dbg!("switching key",i,j,k);
let mut ct: CipherTextQ = Default::default(); let mut ct: CipherTextQ = Default::default();
ct.b = - Wrapping(old_sk.0[i]) * Wrapping(j as i32) * Wrapping(KS_TABLE[k]) + Wrapping(sample(&CHI2, rng)); // ct.a.par_map_inplace(|x| *x = rand::thread_rng().gen());
for l in 0..n { ct.a.par_map_inplace(|x| *x = Wrapping(i as i32) * Wrapping(j as i32) * Wrapping(k as i32));
ct.a[l] = rng.gen(); ct.b = // Wrapping(sample(&CHI2))
let addend = ct.a[l] * Wrapping(new_sk.0[l]); - Wrapping(old_sk[i]) * Wrapping(j as i32) * Wrapping(KS_TABLE[k])
// dbg!(ct.b, addend, isize::MAX, isize::MIN); + ct.a.dot::<Array<Wrapping<i32>,Ix1>>(&new_sk.mapv(Wrapping));
// let upper = addend <= isize::MAX - ct.b; // + Zip::from(&ct.a).and(new_sk).fold(Wrapping(0), |acc, &x, &y| acc + x * Wrapping(y));
// let lower = addend >= isize::MIN - ct.b; // + ct.a.par_iter().zip(new_sk.par_iter()).map(|(x,y)| x*Wrapping(y)).sum();
// ct.b = if upper && lower { res[[i,j,k]] = ct;
// ct.b + addend
// } else if !upper && lower {
// isize::MIN + (addend - (isize::MAX - ct.b)) - 1
// } else {
// isize::MAX - (- addend - (ct.b - isize::MIN)) + 1
// };
ct.b = ct.b + addend;
// dbg!(ct.b);
}
res.0[[i,j,k]] = ct;
} }
} }
} }
// dbg!(&res[[0,0,0]]);
eprintln!("evaluation key 3 -> switching key generation complete");
} }
pub fn keySwitch(res: &mut CipherTextQ, ksk: &SwitchingKey, ct: &CipherTextQN) { pub fn key_switch(res: &mut CipherTextQ, ksk: &SwitchingKey, ct: &CipherTextQN) {
for k in 0..n { for k in 0..n {
res.a[k] = Wrapping(0); res.a[k] = Wrapping(0);
// println!("res.a[{}] = {}", k, res.a[k]);
} }
res.b = ct.b; res.b = ct.b;
// dbg!(res.b);
for i in 0..N { for i in 0..N {
let a: UZmodQ = Wrapping(- ct.a[i].0 as u32); let mut a: UZmodQ = Wrapping(0) - Wrapping(ct.a[i].0 as u32);
for j in 0..KS_BASE { // println!("i = {}, a = {}", i, a);
for j in 0..KS_EXP {
let a0: UZmodQ = a % Wrapping(KS_BASE as u32); let a0: UZmodQ = a % Wrapping(KS_BASE as u32);
// print!("i = {}, j = {}, ksk[i][{}][j].a = [", i, j, a0.0);
for k in 0..n { for k in 0..n {
res.a[k] -= ksk.0[[i,a0.0 as usize,j]].a[k]; res.a[k] -= ksk[[i,a0.0 as usize,j]].a[k];
res.b -= ksk.0[[i,a0.0 as usize,j]].b; // print!("{},", ksk[[i,a0.0 as usize,j]].a[k]);
} }
// println!("]");
res.b -= ksk[[i,a0.0 as usize,j]].b;
a /= Wrapping(KS_BASE as u32);
} }
} }
} }
pub fn modSwitch(ct: &mut CipherText, c: &CipherTextQ, rng: &mut rand::rngs::ThreadRng) { pub fn mod_switch(ct: &mut CipherText, c: &CipherTextQ) {
for i in 0..n { ct.a = c.a.mapv(round_q_Q);
ct.a[i] = round_qQ(c.a[i]); ct.b = round_q_Q(c.b);
}
ct.b = round_qQ(c.b);
} }
pub fn round_qQ(v_: ZmodQ) -> i32 { pub fn round_q_Q(v: ZmodQ) -> i32 {
(0.5 + (v_.0 as f64) * (q as f64) / (Q as f64)).floor() as i32 (0.5 + (v.0 as f64) * (q as f64) / (Q as f64)).floor() as i32
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment