Commit 34bf4f87 authored by Sai Tarun Inaganti's avatar Sai Tarun Inaganti

No commit message

No commit message
parent dc9c56c1
......@@ -7,7 +7,8 @@ edition = "2018"
[dependencies]
fftw = { version = "*" }
ndarray = { version = "*" }
ndarray = { version = "*", features = ["rayon"] }
num = { version = "*" }
rand = { version = "*" }
rayon = { version = "*" }
strum = { version = "*" }
\ No newline at end of file
use std::fs::File;
use std::process::exit;
use fftw::types::*;
// use fftw::types::*;
use ::fhew::{
*,
BinGate::*,
......@@ -9,8 +10,8 @@ use ::fhew::{
use rand::Rng;
fn help(cmd: &String) {
eprintln!("\nusage: {} n\n", cmd);
eprintln!(" Generate a secret key sk and evaluation key ek, and repeat the following test n times:");
eprintln!("\nusage: {} <count>\n", cmd);
eprintln!(" Generate a secret key sk and evaluation key ek, and repeat the following test <count> times:");
eprintln!(" - generate random bits b1,b2,b3,b4");
eprintln!(" - compute ciphertexts c1, c2, c3 and c4 encrypting b1, b2, b3 and b4 under sk");
eprintln!(" - homomorphically compute the encrypted (c1 NAND c2) NAND (c3 NAND c4)");
......@@ -28,19 +29,19 @@ fn cleartext_gate(v1: bool, v2: bool, gate: BinGate) -> bool {
}
fn eprint_gate(gate: BinGate) {
match gate {
OR => eprint!(" OR\t"),
AND => eprint!(" AND\t"),
NOR => eprint!(" NOT\t"),
NAND => eprint!(" NAND\t")
OR => eprint!("OR"),
AND => eprint!("AND"),
NOR => eprint!("NOR"),
NAND => eprint!("NAND")
}
}
fn main() {
// assert_eq!(q, 512);
let mut rng = rand::thread_rng();
// let mut rng = rand::thread_rng();
let mut ffto: FFT = Default::default();
// fftSetup(&mut ffto);
let mut tTestMSB: RingFFT = Default::default();
let mut t_test_msb: RingFFT = RingFFT();
let args: Vec<String> = std::env::args().collect();
if args.len() != 2 {
......@@ -48,14 +49,19 @@ fn main() {
}
let count: i32 = args[1].parse().unwrap();
eprintln!("Setting up FHEW");
fhew::setup(&mut ffto, &mut tTestMSB);
fhew::setup(&mut ffto, &mut t_test_msb);
eprint!("Generating secret key ... ");
let mut lwe_sk: lwe::SecretKey = Default::default();
lwe::keyGen(&mut lwe_sk, &mut rng);
let mut lwe_sk: lwe::SecretKey = SecretKey();
lwe::key_gen(&mut lwe_sk);
// dbg!(&lwe_sk);
eprintln!("Done.\n");
eprintln!("Generating evaluation key ... this may take a while ... ");
let mut ek: EvalKey = Default::default();
fhew::keyGen(&mut ek, &lwe_sk, &mut rng, &mut ffto);
fhew::key_gen(&mut ek, &lwe_sk, &mut ffto);
// let mut f = File::create("key.txt").unwrap();
// fwrite_ek(&ek, &mut f);
eprintln!("Done.\n");
eprintln!("Testing depth-2 homomorphic circuits {} times.", count);
eprintln!("Circuit shape : (a GATE NOT(b)) GATE (c GATE d)\n");
......@@ -65,48 +71,67 @@ fn main() {
let (mut se1, mut se2, mut e1, mut e2, mut e12): (CipherText, CipherText, CipherText, CipherText, CipherText)
= Default::default();
for i in 1..(3*count) {
if i % 3 != 0 {
v1 = rng.gen::<i32>() % 2;
v2 = rng.gen::<i32>() % 2;
lwe::encrypt(&mut e1, &lwe_sk, v1, &mut rng);
lwe::encrypt(&mut e2, &lwe_sk, v2, &mut rng);
if i % 3 == 1 {
eprint!(" NOT\tEnc({}) = ", v2);
for i in 1..(3*count+1) {
// if i != 1 {break;}
if i % 3 != 0 { // 1,2
v1 = (rand::thread_rng().gen::<u32>() % 2) as i32;
v2 = (rand::thread_rng().gen::<u32>() % 2) as i32;
// v1 = 0;
// v2 = 0;
lwe::encrypt(&mut e1, &lwe_sk, v1);
lwe::encrypt(&mut e2, &lwe_sk, v2);
// for t in 0..n {
// println!("t = {}, lwe_sk[t] = {}, e1.a[t] = {}, e2.a[t] = {}", t, lwe_sk[t], e1.a[t], e2.a[t]);
// }
if i % 3 == 1 { // 1
eprint!("\tNOT\tEnc({}) = ", v2);
let e2_temp = e2.clone();
fhew::hom_not(&mut e2, &e2_temp);
let notv2 = lwe::decrypt(&lwe_sk, &e2);
eprintln!("Enc({})", v2);
if !(notv2 == !v2) {
eprintln!("Enc({})", notv2);
// dbg!(v2,notv2,!v2,!notv2);
if !(notv2 != v2 && notv2 * v2 == 0) {
eprintln!("ERROR: incorrect NOT Homomorphic computation at iteration {}", i+1);
exit(1);
}
v2 = !v2;
v2 = if v2 == 0 {1} else {0};
}
} else {
} else { // 3
v1 = sv1;
v2 = sv2;
e1 = se1.clone();
e2 = se2.clone();
}
let gate: BinGate = match rng.gen::<usize>() % 4 {
// let gate: BinGate = BinGate::NAND;
let gate: BinGate = match rand::thread_rng().gen::<usize>() % 4 {
0 => BinGate::OR,
1 => BinGate::AND,
2 => BinGate::NOR,
3 => BinGate::NAND,
_ => BinGate::OR
};
eprint!("Enc({})", v1);
eprint_gate(gate);
eprint!("Enc({}) = ", v2);
fhew::hom_gate(&mut e12, gate, &ek, &e1, &e2, &mut ffto, &tTestMSB, &mut rng);
lwe::encrypt(&mut e1, &lwe_sk, v1);
lwe::encrypt(&mut e2, &lwe_sk, v2);
fhew::hom_gate(&mut e12, gate, &ek, &e1, &e2, &mut ffto, &t_test_msb);
let v12: i32 = lwe::decrypt(&lwe_sk, &e12);
eprintln!("Enc({})", v12);
eprint!("Enc({})\t", v1);
eprint_gate(gate);
eprint!("\tEnc({}) = ", v2);
eprint!("Enc({})", v12);
eprintln!("");
// for j in 0..n {
// println!("i = {}, j = {}, e1.a[j] = {}, e2.a[j] = {}, e12.a[j] = {}", i, j, e1.a[j], e2.a[j], e12.a[j]);
// }
// println!("i = {}\ne1.a = {:?}\ne2.a = {:?}\ne12.a = {:?}", i, e1.a, e2.a, e12.a);
// println!("e1.b = {}, e2.b = {}, e12.b = {}", e1.b, e2.b, e12.b);
match i % 3 {
0 => eprintln!(""),
1 => {
sv1 = v12;
se1 = e12.clone();
......@@ -114,10 +139,10 @@ fn main() {
2 => {
sv2 = v12;
se2 = e12.clone();
},
_ => eprintln!("")
}
_ => ()
}
// println!("i = {}, v1 = {}, v2 = {}, v12 = {}", i, v1, v2, v12);
if cleartext_gate(v1 != 0, v2 != 0, gate) != (v12 != 0) {
eprintln!("\n ERROR: incorrect Homomorphic Gate computation at iteration {}", i+1);
exit(1);
......
......@@ -10,207 +10,373 @@ use std::io::{
BufRead,
Write
};
#[derive(Clone)]
struct CtModQ(Array1<CtModQ1>); // [CtModQ1; K2];
impl Default for CtModQ {
fn default() -> Self {
CtModQ(Array::default(K2))
}
// #[derive(Clone,Debug)]
// struct CtModQ(Array1<CtModQ1>); // [CtModQ1; K2];
// impl Default for CtModQ {
// fn default() -> Self {
// CtModQ(Array::default(K2))
// }
// }
pub type CtModQ = Array<ZmodQ,Ix3>;
pub fn CtModQ() -> CtModQ {
Array::default((K2,2,N))
}
#[derive(Clone)]
struct CtModQ1(Array1<RingModQ>); // [RingModQ; 2];
impl Default for CtModQ1 {
fn default() -> Self {
CtModQ1(Array::default(2))
}
// #[derive(Clone,Debug)]
// struct CtModQ1(Array1<RingModQ>); // [RingModQ; 2];
// impl Default for CtModQ1 {
// fn default() -> Self {
// CtModQ1(Array::default(2))
// }
// }
pub type CtModQ1 = Array<ZmodQ,Ix2>;
pub fn CtModQ1() -> CtModQ1 {
Array::default((2,N))
}
#[derive(Clone)]
struct DctModQ(Array1<DctModQ1>); // [DctModQ1; K2];
impl Default for DctModQ {
fn default() -> Self {
DctModQ(Array::default(K2))
}
// #[derive(Clone,Debug)]
// struct DctModQ(Array1<DctModQ1>); // [DctModQ1; K2];
// impl Default for DctModQ {
// fn default() -> Self {
// DctModQ(Array::default(K2))
// }
// }
pub type DctModQ = Array<ZmodQ,Ix3>;
pub fn DctModQ() -> DctModQ {
Array::default((K2,K2,N))
}
#[derive(Clone)]
struct DctModQ1(Array1<RingModQ>); // [RingModQ; K2];
impl Default for DctModQ1 {
fn default() -> Self {
DctModQ1(Array::default(K2))
}
// #[derive(Clone,Debug)]
// struct DctModQ1(Array1<RingModQ>); // [RingModQ; K2];
// impl Default for DctModQ1 {
// fn default() -> Self {
// DctModQ1(Array::default(K2))
// }
// }
pub type DctModQ1 = Array<ZmodQ,Ix2>;
pub fn DctModQ1() -> DctModQ1 {
Array::default((K2,N))
}
#[derive(Clone)]
struct DctFFT(Array1<DctFFT1>); // [DctFFT1; K2];
impl Default for DctFFT {
fn default() -> Self {
DctFFT(Array::default(K2))
}
// #[derive(Clone,Debug)]
// struct DctFFT(Array1<DctFFT1>); // [DctFFT1; K2];
// impl Default for DctFFT {
// fn default() -> Self {
// DctFFT(Array::default(K2))
// }
// }
pub type DctFFT = Array<c64,Ix3>;
pub fn DctFFT() -> DctFFT {
Array::default((2,K2,N2))
}
#[derive(Clone)]
struct DctFFT1(Array1<RingFFT>); // [RingFFT; K2];
impl Default for DctFFT1 {
fn default() -> Self {
DctFFT1(Array::default(K2))
}
// #[derive(Clone,Debug)]
// struct DctFFT1(Array1<RingFFT>); // [RingFFT; K2];
// impl Default for DctFFT1 {
// fn default() -> Self {
// DctFFT1(Array::default(K2))
// }
// }
pub type DctFFT1 = Array<c64,Ix2>;
pub fn DctFFT1() -> DctFFT1 {
Array::default((K2,N2))
}
#[derive(Clone)]
pub struct CtFFT(Array1<CtFFT1>); // [CtFFT1; K2];
impl Default for CtFFT {
fn default() -> Self {
CtFFT(Array::default(K2))
}
// #[derive(Clone,Debug)]
// pub struct CtFFT(Array1<CtFFT1>); // [CtFFT1; K2];
// impl Default for CtFFT {
// fn default() -> Self {
// CtFFT(Array::default(K2))
// }
// }
pub type CtFFT = Array<c64,Ix3>;
pub fn CtFFT() -> CtFFT {
Array::default((K2,2,N2))
}
#[derive(Clone)]
struct CtFFT1(Array1<RingFFT>); // [RingFFT; 2];
impl Default for CtFFT1 {
fn default() -> Self {
CtFFT1(Array::default(2))
}
// #[derive(Clone,Debug)]
// struct CtFFT1(Array1<RingFFT>); // [RingFFT; 2];
// impl Default for CtFFT1 {
// fn default() -> Self {
// CtFFT1(Array::default(2))
// }
// }
pub type CtFFT1 = Array<c64,Ix2>;
pub fn CtFFT1() -> CtFFT1 {
Array::default((2,N2))
}
#[derive(Clone)]
pub struct BootstrappingKey(Array3<CtFFT>); // [[[CtFFT; BS_EXP]; BS_BASE]; n];
impl Default for BootstrappingKey {
fn default() -> Self {
BootstrappingKey(Array::default((n,BS_BASE,BS_EXP)))
}
// #[derive(Clone,Debug)]
// pub struct BootstrappingKey(Array3<CtFFT>); // [[[CtFFT; BS_EXP]; BS_BASE]; n];
// impl Default for BootstrappingKey {
// fn default() -> Self {
// BootstrappingKey(Array::default((n,BS_BASE,BS_EXP)))
// }
// }
pub type BootstrappingKey = Array<c64,Ix6>;
pub fn BootstrappingKey() -> BootstrappingKey {
let k = Array::default((n,BS_BASE,BS_EXP,K2,2,N2));
eprintln!("evaluation key 1 -> bootstrapping key zeroing complete");
k
}
#[derive(Default,Clone)]
#[derive(Clone,Debug)]
pub struct EvalKey {
pub BSkey: BootstrappingKey,
pub KSkey: lwe::SwitchingKey
pub bskey: BootstrappingKey,
pub kskey: lwe::SwitchingKey
}
impl Default for EvalKey {
fn default() -> Self {
EvalKey {
bskey: BootstrappingKey(),
kskey: SwitchingKey()
}
}
}
pub fn setup(ffto: &mut FFT, tTestMSB: &mut RingFFT) {
pub fn setup(ffto: &mut FFT, t_test_msb: &mut RingFFT) {
fft_setup(ffto);
let mut tmsb: RingModQ = Default::default();
tmsb.0[0] = Wrapping(-1);
let mut tmsb: RingModQ = RingModQ();
tmsb[0] = Wrapping(-1);
for i in 1..N {
tmsb.0[i] = Wrapping(1);
tmsb[i] = Wrapping(1);
}
fft_forward(ffto, &tmsb, tTestMSB);
fft_forward(ffto, tmsb.view(), t_test_msb.view_mut());
// for i in 0..N2 { // no issue here
// println!("t_test_msb[{}] = {}, {}", i, t_test_msb[i].re, t_test_msb[i].im);
// }
}
fn fhewEncrypt(ct: &mut CtFFT, skFFT: &RingFFT, m: i32, rng: &mut rand::rngs::ThreadRng, ffto: &mut FFT) {
let mut ai: RingFFT = Default::default();
let mut res: CtModQ = Default::default();
fn fhew_encrypt<'a>(mut ct: ArrayViewMut<'a,c64,Ix3>, sk_fft: &RingFFT, m: i32, ffto: &mut FFT) {
let mut ai: RingFFT = RingFFT();
let mut res: CtModQ = CtModQ();
let qi = q as i32;
let Ni = N as i32;
let mut mm: i32 = (((m & qi) + qi) % qi) * (2*Ni/qi);
let mut mm: i32 = (((m % qi) + qi) % qi) * (2*Ni/qi);
let old_mm = mm;
let mut sign: i32 = 1;
let old_sign = sign;
if mm >= Ni {
mm -= Ni;
sign = -1;
}
// println!("m = {}, old_mm = {}, new_mm = {}, old_sign = {}, new_sign = {}", m, old_mm, mm, old_sign, sign);
// println!("stage 1");
for i in 0..K2 {
for k in 0..(Ni as usize) {
res.0[i].0[0].0[k] = rng.gen();
}
fft_forward(ffto, &((res.0[i]).0[0]), &mut ai);
for k in 0..(N2 as usize) {
ai.0[k] = ai.0[k] * skFFT.0[k];
}
fft_backward(ffto, &ai, &mut res.0[i].0[1]);
for k in 0..(Ni as usize) {
// println!("i = {}, k = {}, res = {}", i, k, res.0[i].0[1].0[k]);
res.0[i].0[1].0[k] += Wrapping(sample(&CHI1, rng) as i32);
}
// res[[i,0,k]] = rand::thread_rng().gen();
res[[i,0,k]] = Wrapping(k as i32);
// println!("res[[{},0,{}]] = {}", i, k, res[[i,0,k]]);
}
fft_forward(ffto, res.slice::<Ix1>(s![i,0,..]), ai.view_mut());
for k in 0..(N2 as usize) { // minute differences, mostly because of precision
// print!("i = {}, k = {}, ai_old[k] = ({}, {}), , sk_fft[k] = ({}, {}), ", i, k, ai[k].re, ai[k].im, sk_fft[k].re, sk_fft[k].im);
ai[k] = ai[k] * sk_fft[k];
// println!("ai_new[k] = ({}, {})", ai[k].re, ai[k].im);
}
fft_backward(ffto, ai.view(), res.slice_mut::<Ix1>(s![i,1,..]));
// for k in 0..(Ni as usize) {
// res[[i,1,k]] += Wrapping(sample(&CHI1) as i32);
// }
// for k in 0..(Ni as usize) { // res is clean
// println!("i = {}, k = {}, res[i][0][k] = {}, res[i][1][k] = {}", i, k,
// res[[i,0,k]], res[[i,1,k]]);
// }
}
// println!("stage 2");
for i in 0..K {
res.0[2*i].0[0].0[mm as usize] += Wrapping(sign) * vgprime[i];
res.0[2*i+1].0[1].0[mm as usize] += Wrapping(sign) * vgprime[i];
res[[2*i,0,mm as usize]] += Wrapping(sign) * VGPRIME[i];
res[[2*i+1,1,mm as usize]] += Wrapping(sign) * VGPRIME[i];
// println!("i = {}, mm = {}, res[2*i+0][0][{}] = {}, res[2*i+1][1][{}] = {}", i, mm, mm as usize, res[[2*i,0,mm as usize]], mm as usize, res[[2*i+1,1,mm as usize]]);
}
// println!("stage 3");
for i in 0..K2 {
for j in 0..2 {
fft_forward(ffto, &res.0[i].0[j], &mut ct.0[i].0[j]);
fft_forward(ffto, res.slice::<Ix1>(s![i,j,..]), ct.slice_mut::<Ix1>(s![i,j,..]));
// for k in 0..N2 {
// ct[[i,j,k]] = c64::new(ct[[i,j,k]].re.round(), ct[[i,j,k]].im.round());
// }
}
}
// println!("stage 4");
}
pub fn keyGen(ek: &mut EvalKey, lweSk: &lwe::SecretKey, rng: &mut rand::rngs::ThreadRng, ffto: &mut FFT) {
let mut fhewSK: SecretKeyN = Default::default();
keyGenN(&mut fhewSK, rng);
switchingKeyGen(&mut ek.KSkey, &lweSk, &fhewSK, rng);
pub fn key_gen(ek: &mut EvalKey, lwe_sk: &lwe::SecretKey, ffto: &mut FFT) {
let mut fhew_sk: SecretKeyN = SecretKeyN();
key_gen_N(&mut fhew_sk);
switching_key_gen(&mut ek.kskey, lwe_sk.view(), fhew_sk.view());
let mut fhewSkFFT: RingFFT = Default::default();
let mut fhew_rmq: RingModQ = Default::default();
for i in 0..N {
fhew_rmq.0[i] = Wrapping(fhewSK.0[i]);
}
fft_forward(ffto, &fhew_rmq, &mut fhewSkFFT);
let mut fhew_sk_fft: RingFFT = RingFFT();
let fhew_rmq: RingModQ = fhew_sk.mapv(Wrapping);
fft_forward(ffto, fhew_rmq.view(), fhew_sk_fft.view_mut());
// for i in 0..N2 {
// println!("fhew_sk_fft[{}] = ({}, {})", i, fhew_sk_fft[i].re, fhew_sk_fft[i].im);
// }
for i in 0..n {
for j in 1..BS_BASE {
for k in 0..BS_EXP {
ek.BSkey.0[[i,j,k]] = Default::default();
fhewEncrypt(&mut ek.BSkey.0[[i,j,k]], &fhewSkFFT, lweSk.0[i]*(j as i32)*(BS_TABLE[k] as i32), rng, ffto);
// dbg!("bootstrapping key",i,j,k);
// println!("i0 = {}, i1 = {}, i2 = {}", i, j, k);
fhew_encrypt(ek.bskey.slice_mut::<Ix3>(s![i,j,k,..,..,..]), &fhew_sk_fft, lwe_sk[i]*(j as i32)*(BS_TABLE[k] as i32), ffto); // wrong at 2,4,0,0,0,0
// if i == 2 && j == 4 && k == 0 {
// println!("lwe_sk[i] = {}, j = {}, BS_TABLE[k] = {}, product = {}",
// lwe_sk[i], j as i32, BS_TABLE[k] as i32, lwe_sk[i]*(j as i32)*(BS_TABLE[k] as i32));
// }
}
}
}
// for i in 0..n { // bootstrapping key clear, except for precision issues
// for j in 1..BS_BASE {
// for k in 0..BS_EXP {
// for x in 0..K2 {
// for y in 0..2 {
// for z in 0..N2 {
// let bskey = &ek.bskey[[i,j,k,x,y,z]];
// println!("bskey[{}][{}][{}][{}][{}][{}] = ({}, {})", i, j, k, x, y, z, bskey.re, bskey.im);
// }}}}}}
// for i in 0..N { // switching key is clear
// for j in 0..KS_BASE {
// for k in 0..KS_EXP {
// let kskey = &ek.kskey[[i,j,k]];
// print!("i = {}, j = {}, k = {}, kskey.b = {}, kskey.a = [", i, j, k, kskey.b);
// for t in 0..n {print!("{}, ", kskey.a[t]);}
// print!("]\n");
// }}}
eprintln!("evaluation key 4 -> bootstrapping key generation complete");
}
fn addToAcc(acc: &mut CtFFT1, c: &CtFFT, ffto: &mut FFT) {
let mut ct: CtModQ1 = Default::default();
let mut dct: DctModQ1 = Default::default();
let mut dctFFT: DctFFT1 = Default::default();
fn add_to_acc<'a,'b>(mut acc: ArrayViewMut<'a,c64,Ix2>, c: ArrayView<'b,c64,Ix3>, ffto: &mut FFT) {
let mut ct: CtModQ1 = CtModQ1();
let mut dct: DctModQ1 = DctModQ1();
let mut dct_fft: DctFFT1 = DctFFT1();
for j in 0..2 {
fft_backward(ffto, &acc.0[j], &mut ct.0[j]);
fft_backward(ffto, acc.slice(s![j,..]), ct.slice_mut(s![j,..]));
// for k in 0..N {
// print!("j = {}, k = {}, ct[j][k] = {}\t", j, k, ct[[j,k]]);
// if k < N2 {
// print!("acc[j][k] = ({}, {})", acc[[j,k]].re, acc[[j,k]].im);
// }
// println!("");
// }
}
for j in 0..2 {
for k in 0..N {
let mut t: ZmodQ = ct.0[j].0[k] * v_inverse;
let mut t: ZmodQ = ct[[j,k]] * V_INVERSE;
// println!("j = {}, k = {}, t = {}, ct[[j,k]] = {}, V_INVERSE = {}", j, k, t, ct[[j,k]], V_INVERSE);
for l in 0..K {
let r: ZmodQ = Wrapping((t.0 << g_bits_32[l]) >> g_bits_32[l]);
t = Wrapping((t - r).0 >> g_bits[l]);
dct.0[j+2*1].0[k] = r;
let t_temp = t.clone();
let r: ZmodQ = Wrapping((t.0 << G_BITS_32[l]) >> G_BITS_32[l]);
t = Wrapping((t - r).0 >> G_BITS[l]);
dct[[j+2*l,k]] = r;
// println!("j = {}, k = {}, l = {}, ct[j][k] = {}, dct[j+2*l][k] = {}, t_old = {}, r = {}, t_new = {}",
// j, k, l, ct[[j,k]], dct[[j+2*l,k]], t_temp, r, t);
}
}
}
for j in 0..K2 {
fft_forward(ffto, &dct.0[j], &mut dctFFT.0[j]);
fft_forward(ffto, dct.slice(s![j,..]), dct_fft.slice_mut(s![j,..]));
}
for j in 0..2 {
for k in 0..N2 {
acc.0[j].0[k] = c64::new(0.0,0.0);
acc[[j,k]] = c64::default();
for l in 0..K2 {
acc.0[j].0[k] += dctFFT.0[l].0[k] * c.0[l].0[j].0[k];
acc[[j,k]] += dct_fft[[l,k]] * c[[l,j,k]]; // here
// println!("l = {}, j = {}, k = {}, dct_fft[l][k] = ({:.0}, {:.0}), c[l][j][k] = ({:.0}, {:.0})", l, j, k,
// dct_fft[[l,k]].re, dct_fft[[l,k]].im, c[[l,j,k]].re, c[[l,j,k]].im);
}
}
// if j == 0 {
// for k in 0..N {
// print!("j = {}, k = {}, ct[j][k] = {}\t", j, k, ct[[j,k]]);
// if k < N2 {
// print!("acc[j][k] = {}", acc[[j,k]]);
// }
// println!("");
// }
// }
}
}
fn initializeAcc(acc: &mut CtFFT1, m: i32, ffto: &mut FFT) {
let mut res: CtModQ1 = Default::default();
fn initialize_acc(acc: &mut CtFFT1, m: i32, ffto: &mut FFT) {
let mut res: CtModQ1 = CtModQ1();
let qi = q as i32;
let Ni = N as i32;
let mut mm = (((m % qi) + qi) % qi) * (2*Ni/qi);
let mut sign = 1;
let old_mm = mm;
let mut sign: i32 = 1;
let old_sign = sign;
if mm >= Ni {
mm -= Ni;
sign = -1;
}
// println!("m = {}, old_mm = {}, new_mm = {}, old_sign = {}, new_sign = {}", m, old_mm, mm, old_sign, sign);
for j in 0..2 {
for k in 0..N {
res.0[j].0[k] = Wrapping(0);
}
res[[j,k]] = Wrapping(0);
}
res.0[1].0[mm as usize] += Wrapping(sign) * vgprime[0];
for j in 0..2 {
fft_forward(ffto, &res.0[j], &mut acc.0[j]);
}
res[[1,mm as usize]] += Wrapping(sign) * VGPRIME[0];
// for j in 0..2 {
// for k in 0..N {
// println!("j = {}, k = {}, res[j][k] = {}", j, k, res[[j,k]]);
// }
// }
for j in 0..2 { // this is the culprit
fft_forward(ffto, res.slice(s![j,..]), acc.slice_mut(s![j,..]));
}
// for j in 0..2 {
// for k in 0..N2 {
// println!("j = {}, k = {}, acc[j][k] = ({}, {})", j, k, acc[[j,k]].re, acc[[j,k]].im);
// }
// }
}
fn memberTest(t: &RingFFT, c: &CtFFT1, ffto: &mut FFT) -> CipherTextQN {
let mut temp: RingFFT = Default::default();
let mut tempModQ: RingModQ = Default::default();
fn member_test(t: &RingFFT, c: &CtFFT1, ffto: &mut FFT) -> CipherTextQN {
let mut temp: RingFFT = RingFFT();
let mut temp_mod_q: RingModQ = RingModQ();
let mut ct: CipherTextQN = Default::default();
// for i in 0..N2 {
// dbg!(i,t[[i]]);
// }
// for i in 0..2 {
// for j in 0..N2 {
// dbg!(i,j,c[[i,j]]);
// }
// }
for i in 0..N2 {
temp.0[i] = (c.0[0].0[i] * t.0[i]).conj();
}
fft_backward(ffto, &temp, &mut tempModQ);
temp[i] = (c[[0,i]] * t[i]).conj();
// print!("i = {}, temp[i] = {}, c[1][j] = {}, t[i] = {}\n", i, temp[i], c[[0,i]], t[i]);
}
// for i in 0..N2 {
// dbg!(i,temp[[i]]);
// }
// dbg!(&temp);
fft_backward(ffto, temp.view(), temp_mod_q.view_mut());
// for i in 0..N {
// dbg!(i,temp_mod_q[[i]]);
// }
for i in 0..N {
ct.a[i] = tempModQ.0[i];
ct.a[i] = temp_mod_q[i];
}
for i in 0..N2 {
temp.0[i] = c.0[1].0[i] * t.0[i];
}
fft_backward(ffto, &temp, &mut tempModQ);
ct.b = v + tempModQ.0[0];
temp[i] = c[[1,i]] * t[i];
// dbg!(i,temp[i],c[[1,i]],t[i]);
}
// for i in 0..N2 {
// dbg!(i,temp[[i]]);
// }
fft_backward(ffto, temp.view(), temp_mod_q.view_mut());
// for i in 0..N {
// dbg!(i,temp_mod_q[[i]]);
// }
ct.b = V + temp_mod_q[0];
// dbg!(ct.b,V,temp_mod_q[0]);
ct
}
pub fn hom_gate(
......@@ -220,39 +386,59 @@ pub fn hom_gate(
ct1: &CipherText,
ct2: &CipherText,
ffto: &mut FFT,
tTestMSB: &RingFFT,
rng: &mut rand::rngs::ThreadRng
t_test_msb: &RingFFT
) {
let mut e12: CipherText = Default::default();
let qi = q as i32;
for i in 0..n {
if ((ct1.a[i] - ct2.a[i]) % qi) != 0 && ((ct1.a[i] + ct2.a[i]) % qi) != 0 {
break;
}
if i == n - 1 {
panic!("ERROR: Please only use independant ciphertexts as inputs.");
}
}
// for i in 0..n {
// if ((ct1.a[i] - ct2.a[i]) % qi) != 0 && ((ct1.a[i] + ct2.a[i]) % qi) != 0 {
// break;
// }
// if i == n - 1 {
// panic!("ERROR: Please only use independant ciphertexts as inputs.");
// }
// }
for i in 0..n {
e12.a[i] = (2*qi - (ct1.a[i] + ct2.a[i])) % qi;
}
e12.b = (GATE_CONST[gate as usize] as i32) - (ct1.b + ct2.b) % qi;
let mut acc: CtFFT1 = Default::default();
initializeAcc(&mut acc, (e12.b + qi/4) % qi, ffto);
e12.b = (GATE_CONST[gate as usize] as i32) - ((ct1.b + ct2.b) % qi);
// println!("e12.b = {}, e12.a = {}", e12.b, e12.a);
// println!("ct1.b = {}, ct1.a = {}", ct1.b, ct1.a);
// println!("ct2.b = {}, ct2.a = {}", ct2.b, ct2.a);
let mut acc: CtFFT1 = CtFFT1();
initialize_acc(&mut acc, (e12.b + qi/4) % qi, ffto); // all clear
// for i in 0..2 {
// for j in 0..N2 {
// println!("acc[{}][{}] = {}, {}", i, j, acc[[i,j]].re, acc[[i,j]].im);
// }
// }
for i in 0..n {
let mut a = (qi - e12.a[i] % qi) % qi;
for k in 0..BS_EXP {
let a0 = a % (BS_BASE as i32);
// println!("i = {}, a0 = {}, k = {}", i, a0, k); // all clear
if a0 != 0 {
addToAcc(&mut acc, &ek.BSkey.0[[i,a0 as usize,k]], ffto);
add_to_acc(acc.view_mut(), ek.bskey.slice(s![i,a0 as usize,k,..,..,..]), ffto); // trouble here
}
a /= BS_BASE as i32;
}
}
let eQN: CipherTextQN = memberTest(tTestMSB, &acc, ffto);
let mut eQ: CipherTextQ = Default::default();
keySwitch(&mut eQ, &ek.KSkey, &eQN);
modSwitch(res, &eQ, rng);
// for i in 0..2 {
// for j in 0..N2 {
// println!("acc[{}][{}] = ({}, {})", i, j, acc[[i,j]].re, acc[[i,j]].im);
// }
// }
let e_qn: CipherTextQN = member_test(t_test_msb, &acc, ffto);
// println!("e_qn.b = {}, e_qn.a = {}", e_qn.b, e_qn.a);
let mut e_q: CipherTextQ = Default::default();
key_switch(&mut e_q, &ek.kskey, &e_qn);
// print!("e_q.b = {}, e_q.a = [", e_q.b);
// for t in 0..n {println!("{}, ", e_q.a[t]);}
// print!("]\n");
mod_switch(res, &e_q);
// dbg!(res);
}
pub fn hom_nand(
res: &mut CipherText,
......@@ -260,9 +446,8 @@ pub fn hom_nand(
ct1: &CipherText,
ct2: &CipherText,
ffto: &mut FFT,
tTestMSB: &RingFFT,
rng: &mut rand::rngs::ThreadRng) {
hom_gate(res, NAND, ek, ct1, ct2, ffto, tTestMSB, rng)
t_test_msb: &RingFFT) {
hom_gate(res, NAND, ek, ct1, ct2, ffto, t_test_msb)
}
pub fn hom_not(res: &mut CipherText, ct: &CipherText) {
let qi = q as i32;
......@@ -272,14 +457,14 @@ pub fn hom_not(res: &mut CipherText, ct: &CipherText) {
res.b = (5 * qi / 4 - ct.b) % qi;
}
pub fn fwriteEK(ek: &EvalKey, f: &mut File) {
pub fn fwrite_ek(ek: &EvalKey, f: &mut File) {
for i in 0..n {
for j in 1..BS_BASE {
for k in 0..BS_EXP {
for l in 0..K2 {
for m in 0..2 {
for n_ in 0..N2 {
writeln!(*f, "{}", ek.BSkey.0[[i,j,k]].0[l].0[m].0[n_]).unwrap();
writeln!(*f, "{}", ek.bskey[[i,j,k,l,m,n_]]).unwrap();
}
}
}
......@@ -290,15 +475,15 @@ pub fn fwriteEK(ek: &EvalKey, f: &mut File) {
for j in 0..KS_BASE {
for k in 0..KS_EXP {
for l in 0..n {
writeln!(*f, "{}", ek.KSkey.0[[i,j,k]].a[l]).unwrap();
writeln!(*f, "{}", ek.kskey[[i,j,k]].a[l]).unwrap();
}
writeln!(*f, "{}", ek.KSkey.0[[i,j,k]].b).unwrap();
writeln!(*f, "{}", ek.kskey[[i,j,k]].b).unwrap();
}
}
}
// f.sync_all().unwrap();
}
pub fn freadEK(f: &File) -> EvalKey {
pub fn fread_ek(f: &File) -> EvalKey {
use std::str::FromStr;
let mut ek: EvalKey = Default::default();
let mut b = std::io::BufReader::new(f);
......@@ -310,7 +495,7 @@ pub fn freadEK(f: &File) -> EvalKey {
for m in 0..2 {
for n_ in 0..N2 {
b.read_line(&mut s).unwrap();
ek.BSkey.0[[i,j,k]].0[l].0[m].0[n_] = c64::from_str(&s).unwrap();
ek.bskey[[i,j,k,l,m,n_]] = c64::from_str(&s).unwrap();
}
}
}
......@@ -322,10 +507,10 @@ pub fn freadEK(f: &File) -> EvalKey {
for k in 0..KS_EXP {
for l in 0..n {
b.read_line(&mut s).unwrap();
ek.KSkey.0[[i,j,k]].a[l] = Wrapping(i32::from_str(&s).unwrap());
ek.kskey[[i,j,k]].a[l] = Wrapping(i32::from_str(&s).unwrap());
}
b.read_line(&mut s).unwrap();
ek.KSkey.0[[i,j,k]].b = Wrapping(i32::from_str(&s).unwrap());
ek.kskey[[i,j,k]].b = Wrapping(i32::from_str(&s).unwrap());
}
}
}
......
......@@ -9,26 +9,26 @@ use ndarray::*;
use rand::Rng;
use std::num::Wrapping;
pub const n: usize = 500;
pub const n: usize = 10;
pub const N: usize = 1024;
pub const N2: usize = N/2;
pub const N2: usize = N/2+1;
pub const K: usize = 3;
pub const K: usize = 3; // K
pub const K2: usize = 6;
pub const Q: usize = 1 << 32;
pub const Q: usize = 1 << 32; // Q
pub const q: usize = 512;
pub const q2: usize = 256;
pub const q2: usize = q/2;
type ZmodQ = Wrapping<i32>;
pub type ZmodQ = Wrapping<i32>;
type UZmodQ = Wrapping<u32>;
const v: ZmodQ = Wrapping((1 << 29) + 1);
const v_inverse: ZmodQ = Wrapping(-536870911); // 3758096385;
const V: ZmodQ = Wrapping((1 << 29) + 1);
const V_INVERSE: ZmodQ = Wrapping(-536870911); // 3758096385; 1/V mod Q
const vgprime: [ZmodQ; 3] = [Wrapping(536870913), Wrapping(2048), Wrapping(4194304)]; // [v, v<<11, v<<22];
const g_bits: [isize; 3] = [11, 11, 10];
const g_bits_32: [isize; 3] = [21, 21, 22];
const VGPRIME: [ZmodQ; 3] = [Wrapping(V.0), Wrapping(V.0 << 11), Wrapping(V.0 << 22)]; // [V, V<<11, V<<22];
const G_BITS: [isize; 3] = [11, 11, 10];
const G_BITS_32: [isize; 3] = [21, 21, 22];
pub const KS_BASE: usize = 25;
pub const KS_EXP: usize = 7;
......@@ -46,22 +46,30 @@ pub const BS_BASE: usize = 23;
pub const BS_EXP: usize = 2;
pub const BS_TABLE: [usize; 2] = [1, 23];
#[derive(Clone)]
pub struct RingModQ(pub Array1<ZmodQ>); // [ZmodQ; N];
impl Default for RingModQ {
fn default() -> RingModQ {
RingModQ(Array::default(N))
}
// #[derive(Clone,Debug)]
// pub struct RingModQ(pub Array1<ZmodQ>); // [ZmodQ; N];
// impl Default for RingModQ {
// fn default() -> RingModQ {
// RingModQ(Array::default(N))
// }
// }
pub type RingModQ = Array<ZmodQ,Ix1>;
pub fn RingModQ() -> RingModQ {
Array::default(N)
}
#[derive(Clone)]
pub struct RingFFT(pub Array1<c64>); // [c64; N2];
impl Default for RingFFT {
fn default() -> RingFFT {
RingFFT(Array::default(N2))
}
// #[derive(Clone,Debug)]
// pub struct RingFFT(pub Array1<c64>); // [c64; N2];
// impl Default for RingFFT {
// fn default() -> RingFFT {
// RingFFT(Array::default(N2))
// }
// }
pub type RingFFT = Array<c64,Ix1>;
pub fn RingFFT() -> RingFFT {
Array::default(N2)
}
#[derive(Copy,Clone)]
#[derive(Copy,Clone,Debug)]
pub enum BinGate { OR, AND, NOR, NAND }
const GATE_CONST: [usize; 4] = [15*q/8, 9*q/8, 11*q/8, 13*q/8];
......@@ -75,11 +83,11 @@ struct Distrib {
table: &'static [f64]
}
fn sample(chi: &Distrib, rng: &mut rand::rngs::ThreadRng) -> i32 {
fn sample(chi: &Distrib) -> i32 {
// dbg!(chi.std_dev);
if chi.max != 0 {
if chi.max != 0 { // CHI1, CHI_BINARY
// println!("path 1");
let r: f64 = rng.gen();
let r: f64 = rand::thread_rng().gen();
for i in 0..chi.max {
if r <= chi.table[i as usize] {
return i - chi.offset;
......@@ -89,24 +97,24 @@ fn sample(chi: &Distrib, rng: &mut rand::rngs::ThreadRng) -> i32 {
}
let mut r: f64;
let s = chi.std_dev;
if s < 500.0 {
if s < 500.0 { // CHI3
// println!("path 2");
let mut x: i32;
let maxx= (s*8.0).ceil() as i32;
loop {
x = rng.gen::<i32>() % (2*maxx + 1) - maxx;
r = rng.gen();
x = rand::thread_rng().gen::<i32>() % (2*maxx + 1) - maxx;
r = rand::thread_rng().gen();
// println!("x = {}, y = {}, z = {}", r, x, s);
if r < (- (x*x) as f64 / (2.0*s*s)).exp() {
return x;
}
}
} else {
} else { // CHI2
// println!("path 3");
let mut x: f64;
loop {
x = 16.0 * rng.gen::<f64>() - 8.0;
r = rng.gen();
x = 16.0 * rand::thread_rng().gen::<f64>() - 8.0;
r = rand::thread_rng().gen();
// println!("r = {}\tx = {}\ts = {}", r, x, s);
if r < (- x*x / 2.0).exp() {
return (0.5 + x*s).floor() as i32;
......@@ -174,30 +182,41 @@ impl Default for FFT {
in_: AlignedVec::new(N*2),
out: AlignedVec::new(N+1),
plan_fft_forw: R2CPlan64::aligned(&[N*2], Flag::PATIENT).unwrap(),
plan_fft_back: C2RPlan64::aligned(&[N*2], Flag::PATIENT).unwrap()
plan_fft_back: C2RPlan64::aligned(&[N*2], Flag::PATIENT | Flag::PRESERVEINPUT).unwrap()
}
}
}
pub fn fft_setup(ffto: &mut FFT) {
*ffto = Default::default();
}
pub fn fft_forward(ffto: &mut FFT, val: &RingModQ, res: &mut RingFFT) {
pub fn fft_forward<'a,'b>(ffto: &mut FFT, val: ArrayView<'a,ZmodQ,Ix1>, mut res: ArrayViewMut<'b,c64,Ix1>) {
for k in 0..N {
ffto.in_[k] = val.0[k].0 as f64;
ffto.in_[k] = val[k].0 as f64;
ffto.in_[k+N] = 0.0;
}
ffto.plan_fft_forw.r2c(&mut ffto.in_, &mut ffto.out).unwrap();
for k in 0..N2 {
res.0[k] = ffto.out[2*k+1];
for k in 0..(N2-1) {
res[k] = ffto.out[2*k+1];
// res[k] = ffto.out[k];
}
}
pub fn fft_backward(ffto: &mut FFT, val: &RingFFT, res: &mut RingModQ) {
pub fn fft_backward<'a,'b>(ffto: &mut FFT, val: ArrayView<'a,c64,Ix1>, mut res: ArrayViewMut<'b,ZmodQ,Ix1>) {
for k in 0..N2 {
ffto.out[2*k+1] = val.0[k]/c64::new(N as f64,0.0);
ffto.out[2*k] = c64::new(0.0,0.0);
if k < N2 - 1 {
ffto.out[2*k+1] = val[k]/c64::new(N as f64,0.0);
}
// ffto.out[k] = val[k]; // /c64::new(N as f64,0.0);
}
ffto.plan_fft_back.c2r(&mut ffto.out, &mut ffto.in_).unwrap();
for k in 0..N {
res.0[k] = Wrapping(ffto.in_[k].round() as i32);
// let max: i64 = i32::MAX as i64;
// let min: i64 = i32::MIN as i64;
// let div: i64 = max - min + 1;
// let mut t = ffto.in_[k].round() as i64 % div;
// t = if t > max { t - div } else if t < min { t + div } else { t };
// res[k] = Wrapping(t as i32);
res[k] = Wrapping(ffto.in_[k].round().rem_euclid(2f64.powi(32)) as u32 as i32);
}
}
use rand::Rng;
use crate::*;
use ndarray::parallel::prelude::*;
use ndarray::linalg::*;
use rayon::prelude::*;
#[derive(Clone)]
#[derive(Clone,Debug)]
pub struct CipherText {
pub a: Array1<i32>, // [isize; n],
pub b: i32
......@@ -14,7 +17,7 @@ impl Default for CipherText {
}
}
}
#[derive(Clone)]
#[derive(Clone,Debug)]
pub struct CipherTextQ {
pub a: Array1<ZmodQ>, // [ZmodQ; n],
pub b: ZmodQ
......@@ -27,7 +30,7 @@ impl Default for CipherTextQ {
}
}
}
#[derive(Clone)]
#[derive(Clone,Debug)]
pub struct CipherTextQN {
pub a: Array1<ZmodQ>, // [ZmodQ; n],
pub b: ZmodQ
......@@ -35,123 +38,145 @@ pub struct CipherTextQN {
impl Default for CipherTextQN {
fn default() -> Self {
CipherTextQN {
a: Array::default(n),
a: Array::default(N),
b: Default::default()
}
}
}
#[derive(Clone)]
pub struct SecretKey(pub Array1<i32>); // [isize; n];
impl Default for SecretKey {
fn default() -> Self {
SecretKey(Array::default(n))
}
// #[derive(Clone,Debug)]
// pub struct SecretKey(pub Array1<i32>); // [isize; n];
// impl Default for SecretKey {
// fn default() -> Self {
// SecretKey(Array::default(n))
// }
// }
pub type SecretKey = Array<i32,Ix1>;
pub fn SecretKey() -> SecretKey {
Array::default(n)
}
// #[derive(Clone,Debug)]
// pub struct SecretKeyN(pub Array1<i32>); // [isize; N];
// impl Default for SecretKeyN {
// fn default() -> Self {
// SecretKeyN(Array::default(N))
// }
// }
pub type SecretKeyN = Array<i32,Ix1>;
pub fn SecretKeyN() -> SecretKeyN {
Array::default(N)
}
#[derive(Clone)]
pub struct SecretKeyN(pub Array1<i32>); // [isize; N];
impl Default for SecretKeyN {
fn default() -> Self {
SecretKeyN(Array::default(N))
}
// #[derive(Clone,Debug)]
// pub struct SwitchingKey(pub Array3<CipherTextQ>); // [[[CipherTextQ; KS_EXP]; KS_BASE]; N];
// impl Default for SwitchingKey {
// fn default() -> Self {
// SwitchingKey(Array::default((N,KS_BASE,KS_EXP)))
// }
// }
pub type SwitchingKey = Array<CipherTextQ,Ix3>;
pub fn SwitchingKey() -> SwitchingKey {
let k = Array::default((N,KS_BASE,KS_EXP));
eprintln!("evaluation key 2 -> switching key zeroing complete");
k
}
const qi: i32 = q as i32;
pub fn keyGen(sk: &mut SecretKey, rng: &mut rand::rngs::ThreadRng) {
loop {
pub fn key_gen(sk: &mut SecretKey) {
// loop {
let mut s = 0;
let mut ss = 0;
for i in 0..n {
sk.0[i] = sample(&CHI_BINARY, rng);
s += sk.0[i];
ss += (sk.0[i]).abs();
}
if s.abs() > 5 || (ss - (n as i32) / 2).abs() > 5 {
continue;
} else {
break;
}
// sk[i] = sample(&CHI_BINARY);
sk[i] = BINARY_TABLE[i%3].floor() as i32;
s += sk[i];
ss += (sk[i]).abs();
}
// if s.abs() > 5 || (ss - (n as i32) / 2).abs() > 5 {
// continue;
// } else {
// break;
// }
// }
}
pub fn keyGenN(sk: &mut SecretKeyN, rng: &mut rand::rngs::ThreadRng) {
pub fn key_gen_N(sk: &mut SecretKeyN) {
for i in 0..N {
sk.0[i] = sample(&CHI1, rng);
// sk[i] = sample(&CHI1);
sk[i] = (i as i32) % 2;
}
}
pub fn encrypt(ct: &mut CipherText, sk: &SecretKey, m: i32, rng: &mut rand::rngs::ThreadRng) {
ct.b = (m % 4) * qi / 4 + sample(&CHI3, rng);
pub fn encrypt(ct: &mut CipherText, sk: &SecretKey, m: i32) {
ct.b = (m % 4) * qi / 4; // + sample(&CHI3);
for i in 0..n {
ct.a[i] = rng.gen::<i32>() % qi;
ct.b = (ct.b + ct.a[i] * sk.0[i]) % qi;
// ct.a[i] = rand::thread_rng().gen::<i32>() % qi;
ct.a[i] = (i as i32) % qi;
ct.b = (ct.b + ct.a[i] * sk[i]) % qi;
// println!("i = {}, ct.b = {}", i, ct.b);
}
// println!("m = {}, ct.b = {}, ct->a = {:?}", m, ct.b, ct.a);
}
pub fn decrypt(sk: &SecretKey, ct: &CipherText) -> i32 {
let mut r = ct.b;
// dbg!(r);
for i in 0..n {
r -= ct.a[i] * sk.0[i];
r -= ct.a[i] * sk[i];
// dbg!(r);
}
r = ((r % qi) + qi + qi/8) % qi;
// dbg!(r,qi);
// dbg!(r, qi, 4*r/qi);
4 * r / qi
}
#[derive(Clone)]
pub struct SwitchingKey(pub Array3<CipherTextQ>); // [[[CipherTextQ; KS_EXP]; KS_BASE]; N];
impl Default for SwitchingKey {
fn default() -> Self {
SwitchingKey(Array::default((N,KS_BASE,KS_EXP)))
}
}
pub fn switchingKeyGen(res: &mut SwitchingKey, new_sk: &SecretKey, old_sk: &SecretKeyN, rng: &mut rand::rngs::ThreadRng) {
pub fn switching_key_gen<'b,'c>(res: &mut SwitchingKey, new_sk: ArrayView<'b,i32,Ix1>, old_sk: ArrayView<'c,i32,Ix1>) {
// dbg!(&res[[0,0,0]]);
for i in 0..N {
for j in 0..KS_BASE {
for k in 0..KS_EXP {
// dbg!("switching key",i,j,k);
let mut ct: CipherTextQ = Default::default();
ct.b = - Wrapping(old_sk.0[i]) * Wrapping(j as i32) * Wrapping(KS_TABLE[k]) + Wrapping(sample(&CHI2, rng));
for l in 0..n {
ct.a[l] = rng.gen();
let addend = ct.a[l] * Wrapping(new_sk.0[l]);
// dbg!(ct.b, addend, isize::MAX, isize::MIN);
// let upper = addend <= isize::MAX - ct.b;
// let lower = addend >= isize::MIN - ct.b;
// ct.b = if upper && lower {
// ct.b + addend
// } else if !upper && lower {
// isize::MIN + (addend - (isize::MAX - ct.b)) - 1
// } else {
// isize::MAX - (- addend - (ct.b - isize::MIN)) + 1
// };
ct.b = ct.b + addend;
// dbg!(ct.b);
}
res.0[[i,j,k]] = ct;
// ct.a.par_map_inplace(|x| *x = rand::thread_rng().gen());
ct.a.par_map_inplace(|x| *x = Wrapping(i as i32) * Wrapping(j as i32) * Wrapping(k as i32));
ct.b = // Wrapping(sample(&CHI2))
- Wrapping(old_sk[i]) * Wrapping(j as i32) * Wrapping(KS_TABLE[k])
+ ct.a.dot::<Array<Wrapping<i32>,Ix1>>(&new_sk.mapv(Wrapping));
// + Zip::from(&ct.a).and(new_sk).fold(Wrapping(0), |acc, &x, &y| acc + x * Wrapping(y));
// + ct.a.par_iter().zip(new_sk.par_iter()).map(|(x,y)| x*Wrapping(y)).sum();
res[[i,j,k]] = ct;
}
}
}
// dbg!(&res[[0,0,0]]);
eprintln!("evaluation key 3 -> switching key generation complete");
}
pub fn keySwitch(res: &mut CipherTextQ, ksk: &SwitchingKey, ct: &CipherTextQN) {
pub fn key_switch(res: &mut CipherTextQ, ksk: &SwitchingKey, ct: &CipherTextQN) {
for k in 0..n {
res.a[k] = Wrapping(0);
// println!("res.a[{}] = {}", k, res.a[k]);
}
res.b = ct.b;
// dbg!(res.b);
for i in 0..N {
let a: UZmodQ = Wrapping(- ct.a[i].0 as u32);
for j in 0..KS_BASE {
let mut a: UZmodQ = Wrapping(0) - Wrapping(ct.a[i].0 as u32);
// println!("i = {}, a = {}", i, a);
for j in 0..KS_EXP {
let a0: UZmodQ = a % Wrapping(KS_BASE as u32);
// print!("i = {}, j = {}, ksk[i][{}][j].a = [", i, j, a0.0);
for k in 0..n {
res.a[k] -= ksk.0[[i,a0.0 as usize,j]].a[k];
res.b -= ksk.0[[i,a0.0 as usize,j]].b;
res.a[k] -= ksk[[i,a0.0 as usize,j]].a[k];
// print!("{},", ksk[[i,a0.0 as usize,j]].a[k]);
}
// println!("]");
res.b -= ksk[[i,a0.0 as usize,j]].b;
a /= Wrapping(KS_BASE as u32);
}
}
}
pub fn modSwitch(ct: &mut CipherText, c: &CipherTextQ, rng: &mut rand::rngs::ThreadRng) {
for i in 0..n {
ct.a[i] = round_qQ(c.a[i]);
}
ct.b = round_qQ(c.b);
pub fn mod_switch(ct: &mut CipherText, c: &CipherTextQ) {
ct.a = c.a.mapv(round_q_Q);
ct.b = round_q_Q(c.b);
}
pub fn round_qQ(v_: ZmodQ) -> i32 {
(0.5 + (v_.0 as f64) * (q as f64) / (Q as f64)).floor() as i32
pub fn round_q_Q(v: ZmodQ) -> i32 {
(0.5 + (v.0 as f64) * (q as f64) / (Q as f64)).floor() as i32
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment