Commit 34bf4f87 authored by Sai Tarun Inaganti's avatar Sai Tarun Inaganti

No commit message

No commit message
parent dc9c56c1
...@@ -7,7 +7,8 @@ edition = "2018" ...@@ -7,7 +7,8 @@ edition = "2018"
[dependencies] [dependencies]
fftw = { version = "*" } fftw = { version = "*" }
ndarray = { version = "*" } ndarray = { version = "*", features = ["rayon"] }
num = { version = "*" } num = { version = "*" }
rand = { version = "*" } rand = { version = "*" }
rayon = { version = "*" }
strum = { version = "*" } strum = { version = "*" }
\ No newline at end of file
use std::fs::File;
use std::process::exit; use std::process::exit;
use fftw::types::*; // use fftw::types::*;
use ::fhew::{ use ::fhew::{
*, *,
BinGate::*, BinGate::*,
...@@ -9,8 +10,8 @@ use ::fhew::{ ...@@ -9,8 +10,8 @@ use ::fhew::{
use rand::Rng; use rand::Rng;
fn help(cmd: &String) { fn help(cmd: &String) {
eprintln!("\nusage: {} n\n", cmd); eprintln!("\nusage: {} <count>\n", cmd);
eprintln!(" Generate a secret key sk and evaluation key ek, and repeat the following test n times:"); eprintln!(" Generate a secret key sk and evaluation key ek, and repeat the following test <count> times:");
eprintln!(" - generate random bits b1,b2,b3,b4"); eprintln!(" - generate random bits b1,b2,b3,b4");
eprintln!(" - compute ciphertexts c1, c2, c3 and c4 encrypting b1, b2, b3 and b4 under sk"); eprintln!(" - compute ciphertexts c1, c2, c3 and c4 encrypting b1, b2, b3 and b4 under sk");
eprintln!(" - homomorphically compute the encrypted (c1 NAND c2) NAND (c3 NAND c4)"); eprintln!(" - homomorphically compute the encrypted (c1 NAND c2) NAND (c3 NAND c4)");
...@@ -28,19 +29,19 @@ fn cleartext_gate(v1: bool, v2: bool, gate: BinGate) -> bool { ...@@ -28,19 +29,19 @@ fn cleartext_gate(v1: bool, v2: bool, gate: BinGate) -> bool {
} }
fn eprint_gate(gate: BinGate) { fn eprint_gate(gate: BinGate) {
match gate { match gate {
OR => eprint!(" OR\t"), OR => eprint!("OR"),
AND => eprint!(" AND\t"), AND => eprint!("AND"),
NOR => eprint!(" NOT\t"), NOR => eprint!("NOR"),
NAND => eprint!(" NAND\t") NAND => eprint!("NAND")
} }
} }
fn main() { fn main() {
// assert_eq!(q, 512); // assert_eq!(q, 512);
let mut rng = rand::thread_rng(); // let mut rng = rand::thread_rng();
let mut ffto: FFT = Default::default(); let mut ffto: FFT = Default::default();
// fftSetup(&mut ffto); // fftSetup(&mut ffto);
let mut tTestMSB: RingFFT = Default::default(); let mut t_test_msb: RingFFT = RingFFT();
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
if args.len() != 2 { if args.len() != 2 {
...@@ -48,14 +49,19 @@ fn main() { ...@@ -48,14 +49,19 @@ fn main() {
} }
let count: i32 = args[1].parse().unwrap(); let count: i32 = args[1].parse().unwrap();
eprintln!("Setting up FHEW"); eprintln!("Setting up FHEW");
fhew::setup(&mut ffto, &mut tTestMSB); fhew::setup(&mut ffto, &mut t_test_msb);
eprint!("Generating secret key ... "); eprint!("Generating secret key ... ");
let mut lwe_sk: lwe::SecretKey = Default::default(); let mut lwe_sk: lwe::SecretKey = SecretKey();
lwe::keyGen(&mut lwe_sk, &mut rng); lwe::key_gen(&mut lwe_sk);
// dbg!(&lwe_sk);
eprintln!("Done.\n"); eprintln!("Done.\n");
eprintln!("Generating evaluation key ... this may take a while ... "); eprintln!("Generating evaluation key ... this may take a while ... ");
let mut ek: EvalKey = Default::default(); let mut ek: EvalKey = Default::default();
fhew::keyGen(&mut ek, &lwe_sk, &mut rng, &mut ffto); fhew::key_gen(&mut ek, &lwe_sk, &mut ffto);
// let mut f = File::create("key.txt").unwrap();
// fwrite_ek(&ek, &mut f);
eprintln!("Done.\n"); eprintln!("Done.\n");
eprintln!("Testing depth-2 homomorphic circuits {} times.", count); eprintln!("Testing depth-2 homomorphic circuits {} times.", count);
eprintln!("Circuit shape : (a GATE NOT(b)) GATE (c GATE d)\n"); eprintln!("Circuit shape : (a GATE NOT(b)) GATE (c GATE d)\n");
...@@ -65,48 +71,67 @@ fn main() { ...@@ -65,48 +71,67 @@ fn main() {
let (mut se1, mut se2, mut e1, mut e2, mut e12): (CipherText, CipherText, CipherText, CipherText, CipherText) let (mut se1, mut se2, mut e1, mut e2, mut e12): (CipherText, CipherText, CipherText, CipherText, CipherText)
= Default::default(); = Default::default();
for i in 1..(3*count) { for i in 1..(3*count+1) {
if i % 3 != 0 { // if i != 1 {break;}
v1 = rng.gen::<i32>() % 2; if i % 3 != 0 { // 1,2
v2 = rng.gen::<i32>() % 2; v1 = (rand::thread_rng().gen::<u32>() % 2) as i32;
lwe::encrypt(&mut e1, &lwe_sk, v1, &mut rng); v2 = (rand::thread_rng().gen::<u32>() % 2) as i32;
lwe::encrypt(&mut e2, &lwe_sk, v2, &mut rng); // v1 = 0;
if i % 3 == 1 { // v2 = 0;
eprint!(" NOT\tEnc({}) = ", v2); lwe::encrypt(&mut e1, &lwe_sk, v1);
lwe::encrypt(&mut e2, &lwe_sk, v2);
// for t in 0..n {
// println!("t = {}, lwe_sk[t] = {}, e1.a[t] = {}, e2.a[t] = {}", t, lwe_sk[t], e1.a[t], e2.a[t]);
// }
if i % 3 == 1 { // 1
eprint!("\tNOT\tEnc({}) = ", v2);
let e2_temp = e2.clone(); let e2_temp = e2.clone();
fhew::hom_not(&mut e2, &e2_temp); fhew::hom_not(&mut e2, &e2_temp);
let notv2 = lwe::decrypt(&lwe_sk, &e2); let notv2 = lwe::decrypt(&lwe_sk, &e2);
eprintln!("Enc({})", v2); eprintln!("Enc({})", notv2);
if !(notv2 == !v2) { // dbg!(v2,notv2,!v2,!notv2);
if !(notv2 != v2 && notv2 * v2 == 0) {
eprintln!("ERROR: incorrect NOT Homomorphic computation at iteration {}", i+1); eprintln!("ERROR: incorrect NOT Homomorphic computation at iteration {}", i+1);
exit(1); exit(1);
} }
v2 = !v2; v2 = if v2 == 0 {1} else {0};
} }
} else { } else { // 3
v1 = sv1; v1 = sv1;
v2 = sv2; v2 = sv2;
e1 = se1.clone(); e1 = se1.clone();
e2 = se2.clone(); e2 = se2.clone();
} }
let gate: BinGate = match rng.gen::<usize>() % 4 { // let gate: BinGate = BinGate::NAND;
let gate: BinGate = match rand::thread_rng().gen::<usize>() % 4 {
0 => BinGate::OR, 0 => BinGate::OR,
1 => BinGate::AND, 1 => BinGate::AND,
2 => BinGate::NOR, 2 => BinGate::NOR,
3 => BinGate::NAND, 3 => BinGate::NAND,
_ => BinGate::OR _ => BinGate::OR
}; };
eprint!("Enc({})", v1);
eprint_gate(gate);
eprint!("Enc({}) = ", v2);
fhew::hom_gate(&mut e12, gate, &ek, &e1, &e2, &mut ffto, &tTestMSB, &mut rng); lwe::encrypt(&mut e1, &lwe_sk, v1);
lwe::encrypt(&mut e2, &lwe_sk, v2);
fhew::hom_gate(&mut e12, gate, &ek, &e1, &e2, &mut ffto, &t_test_msb);
let v12: i32 = lwe::decrypt(&lwe_sk, &e12); let v12: i32 = lwe::decrypt(&lwe_sk, &e12);
eprintln!("Enc({})", v12); eprint!("Enc({})\t", v1);
eprint_gate(gate);
eprint!("\tEnc({}) = ", v2);
eprint!("Enc({})", v12);
eprintln!("");
// for j in 0..n {
// println!("i = {}, j = {}, e1.a[j] = {}, e2.a[j] = {}, e12.a[j] = {}", i, j, e1.a[j], e2.a[j], e12.a[j]);
// }
// println!("i = {}\ne1.a = {:?}\ne2.a = {:?}\ne12.a = {:?}", i, e1.a, e2.a, e12.a);
// println!("e1.b = {}, e2.b = {}, e12.b = {}", e1.b, e2.b, e12.b);
match i % 3 { match i % 3 {
0 => eprintln!(""),
1 => { 1 => {
sv1 = v12; sv1 = v12;
se1 = e12.clone(); se1 = e12.clone();
...@@ -114,10 +139,10 @@ fn main() { ...@@ -114,10 +139,10 @@ fn main() {
2 => { 2 => {
sv2 = v12; sv2 = v12;
se2 = e12.clone(); se2 = e12.clone();
},
_ => eprintln!("")
} }
_ => () // println!("i = {}, v1 = {}, v2 = {}, v12 = {}", i, v1, v2, v12);
}
if cleartext_gate(v1 != 0, v2 != 0, gate) != (v12 != 0) { if cleartext_gate(v1 != 0, v2 != 0, gate) != (v12 != 0) {
eprintln!("\n ERROR: incorrect Homomorphic Gate computation at iteration {}", i+1); eprintln!("\n ERROR: incorrect Homomorphic Gate computation at iteration {}", i+1);
exit(1); exit(1);
......
...@@ -10,207 +10,373 @@ use std::io::{ ...@@ -10,207 +10,373 @@ use std::io::{
BufRead, BufRead,
Write Write
}; };
#[derive(Clone)]
struct CtModQ(Array1<CtModQ1>); // [CtModQ1; K2]; // #[derive(Clone,Debug)]
impl Default for CtModQ { // struct CtModQ(Array1<CtModQ1>); // [CtModQ1; K2];
fn default() -> Self { // impl Default for CtModQ {
CtModQ(Array::default(K2)) // fn default() -> Self {
} // CtModQ(Array::default(K2))
// }
// }
pub type CtModQ = Array<ZmodQ,Ix3>;
pub fn CtModQ() -> CtModQ {
Array::default((K2,2,N))
} }
#[derive(Clone)]
struct CtModQ1(Array1<RingModQ>); // [RingModQ; 2]; // #[derive(Clone,Debug)]
impl Default for CtModQ1 { // struct CtModQ1(Array1<RingModQ>); // [RingModQ; 2];
fn default() -> Self { // impl Default for CtModQ1 {
CtModQ1(Array::default(2)) // fn default() -> Self {
} // CtModQ1(Array::default(2))
// }
// }
pub type CtModQ1 = Array<ZmodQ,Ix2>;
pub fn CtModQ1() -> CtModQ1 {
Array::default((2,N))
} }
#[derive(Clone)]
struct DctModQ(Array1<DctModQ1>); // [DctModQ1; K2]; // #[derive(Clone,Debug)]
impl Default for DctModQ { // struct DctModQ(Array1<DctModQ1>); // [DctModQ1; K2];
fn default() -> Self { // impl Default for DctModQ {
DctModQ(Array::default(K2)) // fn default() -> Self {
} // DctModQ(Array::default(K2))
// }
// }
pub type DctModQ = Array<ZmodQ,Ix3>;
pub fn DctModQ() -> DctModQ {
Array::default((K2,K2,N))
} }
#[derive(Clone)]
struct DctModQ1(Array1<RingModQ>); // [RingModQ; K2]; // #[derive(Clone,Debug)]
impl Default for DctModQ1 { // struct DctModQ1(Array1<RingModQ>); // [RingModQ; K2];
fn default() -> Self { // impl Default for DctModQ1 {
DctModQ1(Array::default(K2)) // fn default() -> Self {
} // DctModQ1(Array::default(K2))
// }
// }
pub type DctModQ1 = Array<ZmodQ,Ix2>;
pub fn DctModQ1() -> DctModQ1 {
Array::default((K2,N))
} }
#[derive(Clone)]
struct DctFFT(Array1<DctFFT1>); // [DctFFT1; K2]; // #[derive(Clone,Debug)]
impl Default for DctFFT { // struct DctFFT(Array1<DctFFT1>); // [DctFFT1; K2];
fn default() -> Self { // impl Default for DctFFT {
DctFFT(Array::default(K2)) // fn default() -> Self {
} // DctFFT(Array::default(K2))
// }
// }
pub type DctFFT = Array<c64,Ix3>;
pub fn DctFFT() -> DctFFT {
Array::default((2,K2,N2))
} }
#[derive(Clone)]
struct DctFFT1(Array1<RingFFT>); // [RingFFT; K2]; // #[derive(Clone,Debug)]
impl Default for DctFFT1 { // struct DctFFT1(Array1<RingFFT>); // [RingFFT; K2];
fn default() -> Self { // impl Default for DctFFT1 {
DctFFT1(Array::default(K2)) // fn default() -> Self {
} // DctFFT1(Array::default(K2))
// }
// }
pub type DctFFT1 = Array<c64,Ix2>;
pub fn DctFFT1() -> DctFFT1 {
Array::default((K2,N2))
} }
#[derive(Clone)]
pub struct CtFFT(Array1<CtFFT1>); // [CtFFT1; K2]; // #[derive(Clone,Debug)]
impl Default for CtFFT { // pub struct CtFFT(Array1<CtFFT1>); // [CtFFT1; K2];
fn default() -> Self { // impl Default for CtFFT {
CtFFT(Array::default(K2)) // fn default() -> Self {
} // CtFFT(Array::default(K2))
// }
// }
pub type CtFFT = Array<c64,Ix3>;
pub fn CtFFT() -> CtFFT {
Array::default((K2,2,N2))
} }
#[derive(Clone)]
struct CtFFT1(Array1<RingFFT>); // [RingFFT; 2]; // #[derive(Clone,Debug)]
impl Default for CtFFT1 { // struct CtFFT1(Array1<RingFFT>); // [RingFFT; 2];
fn default() -> Self { // impl Default for CtFFT1 {
CtFFT1(Array::default(2)) // fn default() -> Self {
} // CtFFT1(Array::default(2))
// }
// }
pub type CtFFT1 = Array<c64,Ix2>;
pub fn CtFFT1() -> CtFFT1 {
Array::default((2,N2))
} }
#[derive(Clone)]
pub struct BootstrappingKey(Array3<CtFFT>); // [[[CtFFT; BS_EXP]; BS_BASE]; n]; // #[derive(Clone,Debug)]
impl Default for BootstrappingKey { // pub struct BootstrappingKey(Array3<CtFFT>); // [[[CtFFT; BS_EXP]; BS_BASE]; n];
fn default() -> Self { // impl Default for BootstrappingKey {
BootstrappingKey(Array::default((n,BS_BASE,BS_EXP))) // fn default() -> Self {
} // BootstrappingKey(Array::default((n,BS_BASE,BS_EXP)))
// }
// }
pub type BootstrappingKey = Array<c64,Ix6>;
pub fn BootstrappingKey() -> BootstrappingKey {
let k = Array::default((n,BS_BASE,BS_EXP,K2,2,N2));
eprintln!("evaluation key 1 -> bootstrapping key zeroing complete");
k
} }
#[derive(Default,Clone)]
#[derive(Clone,Debug)]
pub struct EvalKey { pub struct EvalKey {
pub BSkey: BootstrappingKey, pub bskey: BootstrappingKey,
pub KSkey: lwe::SwitchingKey pub kskey: lwe::SwitchingKey
}
impl Default for EvalKey {
fn default() -> Self {
EvalKey {
bskey: BootstrappingKey(),
kskey: SwitchingKey()
}
}
} }
pub fn setup(ffto: &mut FFT, tTestMSB: &mut RingFFT) { pub fn setup(ffto: &mut FFT, t_test_msb: &mut RingFFT) {
fft_setup(ffto); fft_setup(ffto);
let mut tmsb: RingModQ = Default::default(); let mut tmsb: RingModQ = RingModQ();
tmsb.0[0] = Wrapping(-1); tmsb[0] = Wrapping(-1);
for i in 1..N { for i in 1..N {
tmsb.0[i] = Wrapping(1); tmsb[i] = Wrapping(1);
} }
fft_forward(ffto, &tmsb, tTestMSB); fft_forward(ffto, tmsb.view(), t_test_msb.view_mut());
// for i in 0..N2 { // no issue here
// println!("t_test_msb[{}] = {}, {}", i, t_test_msb[i].re, t_test_msb[i].im);
// }
} }
fn fhewEncrypt(ct: &mut CtFFT, skFFT: &RingFFT, m: i32, rng: &mut rand::rngs::ThreadRng, ffto: &mut FFT) { fn fhew_encrypt<'a>(mut ct: ArrayViewMut<'a,c64,Ix3>, sk_fft: &RingFFT, m: i32, ffto: &mut FFT) {
let mut ai: RingFFT = Default::default(); let mut ai: RingFFT = RingFFT();
let mut res: CtModQ = Default::default(); let mut res: CtModQ = CtModQ();
let qi = q as i32; let qi = q as i32;
let Ni = N as i32; let Ni = N as i32;
let mut mm: i32 = (((m & qi) + qi) % qi) * (2*Ni/qi); let mut mm: i32 = (((m % qi) + qi) % qi) * (2*Ni/qi);
let old_mm = mm;
let mut sign: i32 = 1; let mut sign: i32 = 1;
let old_sign = sign;
if mm >= Ni { if mm >= Ni {
mm -= Ni; mm -= Ni;
sign = -1; sign = -1;
} }
// println!("m = {}, old_mm = {}, new_mm = {}, old_sign = {}, new_sign = {}", m, old_mm, mm, old_sign, sign);
// println!("stage 1"); // println!("stage 1");
for i in 0..K2 { for i in 0..K2 {
for k in 0..(Ni as usize) { for k in 0..(Ni as usize) {
res.0[i].0[0].0[k] = rng.gen(); // res[[i,0,k]] = rand::thread_rng().gen();
} res[[i,0,k]] = Wrapping(k as i32);
fft_forward(ffto, &((res.0[i]).0[0]), &mut ai); // println!("res[[{},0,{}]] = {}", i, k, res[[i,0,k]]);
for k in 0..(N2 as usize) { }
ai.0[k] = ai.0[k] * skFFT.0[k]; fft_forward(ffto, res.slice::<Ix1>(s![i,0,..]), ai.view_mut());
} for k in 0..(N2 as usize) { // minute differences, mostly because of precision
fft_backward(ffto, &ai, &mut res.0[i].0[1]); // print!("i = {}, k = {}, ai_old[k] = ({}, {}), , sk_fft[k] = ({}, {}), ", i, k, ai[k].re, ai[k].im, sk_fft[k].re, sk_fft[k].im);
for k in 0..(Ni as usize) { ai[k] = ai[k] * sk_fft[k];
// println!("i = {}, k = {}, res = {}", i, k, res.0[i].0[1].0[k]); // println!("ai_new[k] = ({}, {})", ai[k].re, ai[k].im);
res.0[i].0[1].0[k] += Wrapping(sample(&CHI1, rng) as i32); }
} fft_backward(ffto, ai.view(), res.slice_mut::<Ix1>(s![i,1,..]));
// for k in 0..(Ni as usize) {
// res[[i,1,k]] += Wrapping(sample(&CHI1) as i32);
// }
// for k in 0..(Ni as usize) { // res is clean
// println!("i = {}, k = {}, res[i][0][k] = {}, res[i][1][k] = {}", i, k,
// res[[i,0,k]], res[[i,1,k]]);
// }
} }
// println!("stage 2"); // println!("stage 2");
for i in 0..K { for i in 0..K {
res.0[2*i].0[0].0[mm as usize] += Wrapping(sign) * vgprime[i]; res[[2*i,0,mm as usize]] += Wrapping(sign) * VGPRIME[i];
res.0[2*i+1].0[1].0[mm as usize] += Wrapping(sign) * vgprime[i]; res[[2*i+1,1,mm as usize]] += Wrapping(sign) * VGPRIME[i];
// println!("i = {}, mm = {}, res[2*i+0][0][{}] = {}, res[2*i+1][1][{}] = {}", i, mm, mm as usize, res[[2*i,0,mm as usize]], mm as usize, res[[2*i+1,1,mm as usize]]);
} }
// println!("stage 3"); // println!("stage 3");
for i in 0..K2 { for i in 0..K2 {
for j in 0..2 { for j in 0..2 {
fft_forward(ffto, &res.0[i].0[j], &mut ct.0[i].0[j]); fft_forward(ffto, res.slice::<Ix1>(s![i,j,..]), ct.slice_mut::<Ix1>(s![i,j,..]));
// for k in 0..N2 {
// ct[[i,j,k]] = c64::new(ct[[i,j,k]].re.round(), ct[[i,j,k]].im.round());
// }
} }
} }
// println!("stage 4"); // println!("stage 4");
} }
pub fn keyGen(ek: &mut EvalKey, lweSk: &lwe::SecretKey, rng: &mut rand::rngs::ThreadRng, ffto: &mut FFT) { pub fn key_gen(ek: &mut EvalKey, lwe_sk: &lwe::SecretKey, ffto: &mut FFT) {
let mut fhewSK: SecretKeyN = Default::default(); let mut fhew_sk: SecretKeyN = SecretKeyN();
keyGenN(&mut fhewSK, rng); key_gen_N(&mut fhew_sk);
switchingKeyGen(&mut ek.KSkey, &lweSk, &fhewSK, rng); switching_key_gen(&mut ek.kskey, lwe_sk.view(), fhew_sk.view());
let mut fhewSkFFT: RingFFT = Default::default(); let mut fhew_sk_fft: RingFFT = RingFFT();
let mut fhew_rmq: RingModQ = Default::default(); let fhew_rmq: RingModQ = fhew_sk.mapv(Wrapping);
for i in 0..N { fft_forward(ffto, fhew_rmq.view(), fhew_sk_fft.view_mut());
fhew_rmq.0[i] = Wrapping(fhewSK.0[i]); // for i in 0..N2 {
} // println!("fhew_sk_fft[{}] = ({}, {})", i, fhew_sk_fft[i].re, fhew_sk_fft[i].im);
fft_forward(ffto, &fhew_rmq, &mut fhewSkFFT); // }
for i in 0..n { for i in 0..n {
for j in 1..BS_BASE { for j in 1..BS_BASE {
for k in 0..BS_EXP { for k in 0..BS_EXP {
ek.BSkey.0[[i,j,k]] = Default::default();
fhewEncrypt(&mut ek.BSkey.0[[i,j,k]], &fhewSkFFT, lweSk.0[i]*(j as i32)*(BS_TABLE[k] as i32), rng, ffto); // dbg!("bootstrapping key",i,j,k);
// println!("i0 = {}, i1 = {}, i2 = {}", i, j, k);
fhew_encrypt(ek.bskey.slice_mut::<Ix3>(s![i,j,k,..,..,..]), &fhew_sk_fft, lwe_sk[i]*(j as i32)*(BS_TABLE[k] as i32), ffto); // wrong at 2,4,0,0,0,0
// if i == 2 && j == 4 && k == 0 {
// println!("lwe_sk[i] = {}, j = {}, BS_TABLE[k] = {}, product = {}",
// lwe_sk[i], j as i32, BS_TABLE[k] as i32, lwe_sk[i]*(j as i32)*(BS_TABLE[k] as i32));
// }
} }
} }
} }
// for i in 0..n { // bootstrapping key clear, except for precision issues
// for j in 1..BS_BASE {
// for k in 0..BS_EXP {
// for x in 0..K2 {
// for y in 0..2 {
// for z in 0..N2 {
// let bskey = &ek.bskey[[i,j,k,x,y,z]];
// println!("bskey[{}][{}][{}][{}][{}][{}] = ({}, {})", i, j, k, x, y, z, bskey.re, bskey.im);
// }}}}}}
// for i in 0..N { // switching key is clear
// for j in 0..KS_BASE {
// for k in 0..KS_EXP {
// let kskey = &ek.kskey[[i,j,k]];
// print!("i = {}, j = {}, k = {}, kskey.b = {}, kskey.a = [", i, j, k, kskey.b);
// for t in 0..n {print!("{}, ", kskey.a[t]);}
// print!("]\n");
// }}}
eprintln!("evaluation key 4 -> bootstrapping key generation complete");
} }
fn addToAcc(acc: &mut CtFFT1, c: &CtFFT, ffto: &mut FFT) { fn add_to_acc<'a,'b>(mut acc: ArrayViewMut<'a,c64,Ix2>, c: ArrayView<'b,c64,Ix3>, ffto: &mut FFT) {
let mut ct: CtModQ1 = Default::default(); let mut ct: CtModQ1 = CtModQ1();
let mut dct: DctModQ1 = Default::default(); let mut dct: DctModQ1 = DctModQ1();
let mut dctFFT: DctFFT1 = Default::default(); let mut dct_fft: DctFFT1 = DctFFT1();
for j in 0..2 { for j in 0..2 {
fft_backward(ffto, &acc.0[j], &mut ct.0[j]); fft_backward(ffto, acc.slice(s![j,..]), ct.slice_mut(s![j,..]));
// for k in 0..N {
// print!("j = {}, k = {}, ct[j][k] = {}\t", j, k, ct[[j,k]]);
// if k < N2 {
// print!("acc[j][k] = ({}, {})", acc[[j,k]].re, acc[[j,k]].im);
// }
// println!("");
// }
} }
for j in 0..2 { for j in 0..2 {
for k in 0..N { for k in 0..N {
let mut t: ZmodQ = ct.0[j].0[k] * v_inverse; let mut t: ZmodQ = ct[[j,k]] * V_INVERSE;
// println!("j = {}, k = {}, t = {}, ct[[j,k]] = {}, V_INVERSE = {}", j, k, t, ct[[j,k]], V_INVERSE);
for l in 0..K { for l in 0..K {
let r: ZmodQ = Wrapping((t.0 << g_bits_32[l]) >> g_bits_32[l]); let t_temp = t.clone();
t = Wrapping((t - r).0 >> g_bits[l]); let r: ZmodQ = Wrapping((t.0 << G_BITS_32[l]) >> G_BITS_32[l]);
dct.0[j+2*1].0[k] = r; t = Wrapping((t - r).0 >> G_BITS[l]);
dct[[j+2*l,k]] = r;
// println!("j = {}, k = {}, l = {}, ct[j][k] = {}, dct[j+2*l][k] = {}, t_old = {}, r = {}, t_new = {}",
// j, k, l, ct[[j,k]], dct[[j+2*l,k]], t_temp, r, t);
} }
} }
} }
for j in 0..K2 { for j in 0..K2 {
fft_forward(ffto, &dct.0[j], &mut dctFFT.0[j]); fft_forward(ffto, dct.slice(s![j,..]), dct_fft.slice_mut(s![j,..]));
} }
for j in 0..2 { for j in 0..2 {
for k in 0..N2 { for k in 0..N2 {
acc.0[j].0[k] = c64::new(0.0,0.0); acc[[j,k]] = c64::default();
for l in 0..K2 { for l in 0..K2 {
acc.0[j].0[k] += dctFFT.0[l].0[k] * c.0[l].0[j].0[k]; acc[[j,k]] += dct_fft[[l,k]] * c[[l,j,k]]; // here
// println!("l = {}, j = {}, k = {}, dct_fft[l][k] = ({:.0}, {:.0}), c[l][j][k] = ({:.0}, {:.0})", l, j, k,
// dct_fft[[l,k]].re, dct_fft[[l,k]].im, c[[l,j,k]].re, c[[l,j,k]].im);
} }
} }
// if j == 0 {
// for k in 0..N {
// print!("j = {}, k = {}, ct[j][k] = {}\t", j, k, ct[[j,k]]);
// if k < N2 {
// print!("acc[j][k] = {}", acc[[j,k]]);
// }
// println!("");
// }
// }
} }
} }
fn initializeAcc(acc: &mut CtFFT1, m: i32, ffto: &mut FFT) { fn initialize_acc(acc: &mut CtFFT1, m: i32, ffto: &mut FFT) {
let mut res: CtModQ1 = Default::default(); let mut res: CtModQ1 = CtModQ1();
let qi = q as i32; let qi = q as i32;
let Ni = N as i32; let Ni = N as i32;
let mut mm = (((m % qi) + qi) % qi) * (2*Ni/qi); let mut mm = (((m % qi) + qi) % qi) * (2*Ni/qi);
let mut sign = 1; let old_mm = mm;
let mut sign: i32 = 1;
let old_sign = sign;
if mm >= Ni { if mm >= Ni {
mm -= Ni; mm -= Ni;
sign = -1; sign = -1;
} }
// println!("m = {}, old_mm = {}, new_mm = {}, old_sign = {}, new_sign = {}", m, old_mm, mm, old_sign, sign);
for j in 0..2 { for j in 0..2 {
for k in 0..N { for k in 0..N {
res.0[j].0[k] = Wrapping(0); res[[j,k]] = Wrapping(0);
}
} }
res.0[1].0[mm as usize] += Wrapping(sign) * vgprime[0];
for j in 0..2 {
fft_forward(ffto, &res.0[j], &mut acc.0[j]);
} }
res[[1,mm as usize]] += Wrapping(sign) * VGPRIME[0];
// for j in 0..2 {
// for k in 0..N {
// println!("j = {}, k = {}, res[j][k] = {}", j, k, res[[j,k]]);
// }
// }
for j in 0..2 { // this is the culprit
fft_forward(ffto, res.slice(s![j,..]), acc.slice_mut(s![j,..]));
}
// for j in 0..2 {
// for k in 0..N2 {
// println!("j = {}, k = {}, acc[j][k] = ({}, {})", j, k, acc[[j,k]].re, acc[[j,k]].im);
// }
// }
} }
fn memberTest(t: &RingFFT, c: &CtFFT1, ffto: &mut FFT) -> CipherTextQN { fn member_test(t: &RingFFT, c: &CtFFT1, ffto: &mut FFT) -> CipherTextQN {
let mut temp: RingFFT = Default::default(); let mut temp: RingFFT = RingFFT();
let mut tempModQ: RingModQ = Default::default(); let mut temp_mod_q: RingModQ = RingModQ();
let mut ct: CipherTextQN = Default::default(); let mut ct: CipherTextQN = Default::default();
// for i in 0..N2 {
// dbg!(i,t[[i]]);
// }
// for i in 0..2 {
// for j in 0..N2 {
// dbg!(i,j,c[[i,j]]);
// }
// }
for i in 0..N2 { for i in 0..N2 {
temp.0[i] = (c.0[0].0[i] * t.0[i]).conj(); temp[i] = (c[[0,i]] * t[i]).conj();
} // print!("i = {}, temp[i] = {}, c[1][j] = {}, t[i] = {}\n", i, temp[i], c[[0,i]], t[i]);
fft_backward(ffto, &temp, &mut tempModQ); }
// for i in 0..N2 {
// dbg!(i,temp[[i]]);
// }
// dbg!(&temp);
fft_backward(ffto, temp.view(), temp_mod_q.view_mut());
// for i in 0..N {
// dbg!(i,temp_mod_q[[i]]);
// }
for i in 0..N { for i in 0..N {
ct.a[i] = tempModQ.0[i]; ct.a[i] = temp_mod_q[i];
} }
for i in 0..N2 { for i in 0..N2 {
temp.0[i] = c.0[1].0[i] * t.0[i]; temp[i] = c[[1,i]] * t[i];
} // dbg!(i,temp[i],c[[1,i]],t[i]);
fft_backward(ffto, &temp, &mut tempModQ); }
ct.b = v + tempModQ.0[0]; // for i in 0..N2 {
// dbg!(i,temp[[i]]);
// }
fft_backward(ffto, temp.view(), temp_mod_q.view_mut());
// for i in 0..N {
// dbg!(i,temp_mod_q[[i]]);
// }
ct.b = V + temp_mod_q[0];
// dbg!(ct.b,V,temp_mod_q[0]);
ct ct
} }
pub fn hom_gate( pub fn hom_gate(
...@@ -220,39 +386,59 @@ pub fn hom_gate( ...@@ -220,39 +386,59 @@ pub fn hom_gate(
ct1: &CipherText, ct1: &CipherText,
ct2: &CipherText, ct2: &CipherText,
ffto: &mut FFT, ffto: &mut FFT,
tTestMSB: &RingFFT, t_test_msb: &RingFFT
rng: &mut rand::rngs::ThreadRng
) { ) {
let mut e12: CipherText = Default::default(); let mut e12: CipherText = Default::default();
let qi = q as i32; let qi = q as i32;
for i in 0..n { // for i in 0..n {
if ((ct1.a[i] - ct2.a[i]) % qi) != 0 && ((ct1.a[i] + ct2.a[i]) % qi) != 0 { // if ((ct1.a[i] - ct2.a[i]) % qi) != 0 && ((ct1.a[i] + ct2.a[i]) % qi) != 0 {
break; // break;
} // }
if i == n - 1 { // if i == n - 1 {
panic!("ERROR: Please only use independant ciphertexts as inputs."); // panic!("ERROR: Please only use independant ciphertexts as inputs.");
} // }
} // }
for i in 0..n { for i in 0..n {
e12.a[i] = (2*qi - (ct1.a[i] + ct2.a[i])) % qi; e12.a[i] = (2*qi - (ct1.a[i] + ct2.a[i])) % qi;
} }
e12.b = (GATE_CONST[gate as usize] as i32) - (ct1.b + ct2.b) % qi; e12.b = (GATE_CONST[gate as usize] as i32) - ((ct1.b + ct2.b) % qi);
let mut acc: CtFFT1 = Default::default();
initializeAcc(&mut acc, (e12.b + qi/4) % qi, ffto); // println!("e12.b = {}, e12.a = {}", e12.b, e12.a);
// println!("ct1.b = {}, ct1.a = {}", ct1.b, ct1.a);
// println!("ct2.b = {}, ct2.a = {}", ct2.b, ct2.a);
let mut acc: CtFFT1 = CtFFT1();
initialize_acc(&mut acc, (e12.b + qi/4) % qi, ffto); // all clear
// for i in 0..2 {
// for j in 0..N2 {
// println!("acc[{}][{}] = {}, {}", i, j, acc[[i,j]].re, acc[[i,j]].im);
// }
// }
for i in 0..n { for i in 0..n {
let mut a = (qi - e12.a[i] % qi) % qi; let mut a = (qi - e12.a[i] % qi) % qi;
for k in 0..BS_EXP { for k in 0..BS_EXP {
let a0 = a % (BS_BASE as i32); let a0 = a % (BS_BASE as i32);
// println!("i = {}, a0 = {}, k = {}", i, a0, k); // all clear
if a0 != 0 { if a0 != 0 {
addToAcc(&mut acc, &ek.BSkey.0[[i,a0 as usize,k]], ffto); add_to_acc(acc.view_mut(), ek.bskey.slice(s![i,a0 as usize,k,..,..,..]), ffto); // trouble here
} }
a /= BS_BASE as i32; a /= BS_BASE as i32;
} }
} }
let eQN: CipherTextQN = memberTest(tTestMSB, &acc, ffto); // for i in 0..2 {
let mut eQ: CipherTextQ = Default::default(); // for j in 0..N2 {
keySwitch(&mut eQ, &ek.KSkey, &eQN); // println!("acc[{}][{}] = ({}, {})", i, j, acc[[i,j]].re, acc[[i,j]].im);
modSwitch(res, &eQ, rng); // }
// }
let e_qn: CipherTextQN = member_test(t_test_msb, &acc, ffto);
// println!("e_qn.b = {}, e_qn.a = {}", e_qn.b, e_qn.a);
let mut e_q: CipherTextQ = Default::default();
key_switch(&mut e_q, &ek.kskey, &e_qn);
// print!("e_q.b = {}, e_q.a = [", e_q.b);
// for t in 0..n {println!("{}, ", e_q.a[t]);}
// print!("]\n");
mod_switch(res, &e_q);
// dbg!(res);
} }
pub fn hom_nand( pub fn hom_nand(
res: &mut CipherText, res: &mut CipherText,
...@@ -260,9 +446,8 @@ pub fn hom_nand( ...@@ -260,9 +446,8 @@ pub fn hom_nand(
ct1: &CipherText, ct1: &CipherText,
ct2: &CipherText, ct2: &CipherText,
ffto: &mut FFT, ffto: &mut FFT,
tTestMSB: &RingFFT, t_test_msb: &RingFFT) {
rng: &mut rand::rngs::ThreadRng) { hom_gate(res, NAND, ek, ct1, ct2, ffto, t_test_msb)
hom_gate(res, NAND, ek, ct1, ct2, ffto, tTestMSB, rng)
} }
pub fn hom_not(res: &mut CipherText, ct: &CipherText) { pub fn hom_not(res: &mut CipherText, ct: &CipherText) {
let qi = q as i32; let qi = q as i32;
...@@ -272,14 +457,14 @@ pub fn hom_not(res: &mut CipherText, ct: &CipherText) { ...@@ -272,14 +457,14 @@ pub fn hom_not(res: &mut CipherText, ct: &CipherText) {
res.b = (5 * qi / 4 - ct.b) % qi; res.b = (5 * qi / 4 - ct.b) % qi;
} }
pub fn fwriteEK(ek: &EvalKey, f: &mut File) { pub fn fwrite_ek(ek: &EvalKey, f: &mut File) {
for i in 0..n { for i in 0..n {
for j in 1..BS_BASE { for j in 1..BS_BASE {
for k in 0..BS_EXP { for k in 0..BS_EXP {
for l in 0..K2 { for l in 0..K2 {
for m in 0..2 { for m in 0..2 {
for n_ in 0..N2 { for n_ in 0..N2 {
writeln!(*f, "{}", ek.BSkey.0[[i,j,k]].0[l].0[m].0[n_]).unwrap(); writeln!(*f, "{}", ek.bskey[[i,j,k,l,m,n_]]).unwrap();
} }
} }
} }
...@@ -290,15 +475,15 @@ pub fn fwriteEK(ek: &EvalKey, f: &mut File) { ...@@ -290,15 +475,15 @@ pub fn fwriteEK(ek: &EvalKey, f: &mut File) {
for j in 0..KS_BASE { for j in 0..KS_BASE {
for k in 0..KS_EXP { for k in 0..KS_EXP {
for l in 0..n { for l in 0..n {
writeln!(*f, "{}", ek.KSkey.0[[i,j,k]].a[l]).unwrap(); writeln!(*f, "{}", ek.kskey[[i,j,k]].a[l]).unwrap();
} }
writeln!(*f, "{}", ek.KSkey.0[[i,j,k]].b).unwrap(); writeln!(*f, "{}", ek.kskey[[i,j,k]].b).unwrap();
} }
} }
} }
// f.sync_all().unwrap(); // f.sync_all().unwrap();
} }
pub fn freadEK(f: &File) -> EvalKey { pub fn fread_ek(f: &File) -> EvalKey {
use std::str::FromStr; use std::str::FromStr;
let mut ek: EvalKey = Default::default(); let mut ek: EvalKey = Default::default();
let mut b = std::io::BufReader::new(f); let mut b = std::io::BufReader::new(f);
...@@ -310,7 +495,7 @@ pub fn freadEK(f: &File) -> EvalKey { ...@@ -310,7 +495,7 @@ pub fn freadEK(f: &File) -> EvalKey {
for m in 0..2 { for m in 0..2 {
for n_ in 0..N2 { for n_ in 0..N2 {
b.read_line(&mut s).unwrap(); b.read_line(&mut s).unwrap();
ek.BSkey.0[[i,j,k]].0[l].0[m].0[n_] = c64::from_str(&s).unwrap(); ek.bskey[[i,j,k,l,m,n_]] = c64::from_str(&s).unwrap();
} }
} }
} }
...@@ -322,10 +507,10 @@ pub fn freadEK(f: &File) -> EvalKey { ...@@ -322,10 +507,10 @@ pub fn freadEK(f: &File) -> EvalKey {
for k in 0..KS_EXP { for k in 0..KS_EXP {
for l in 0..n { for l in 0..n {
b.read_line(&mut s).unwrap(); b.read_line(&mut s).unwrap();
ek.KSkey.0[[i,j,k]].a[l] = Wrapping(i32::from_str(&s).unwrap()); ek.kskey[[i,j,k]].a[l] = Wrapping(i32::from_str(&s).unwrap());
} }
b.read_line(&mut s).unwrap(); b.read_line(&mut s).unwrap();
ek.KSkey.0[[i,j,k]].b = Wrapping(i32::from_str(&s).unwrap()); ek.kskey[[i,j,k]].b = Wrapping(i32::from_str(&s).unwrap());
} }
} }
} }
......
...@@ -9,26 +9,26 @@ use ndarray::*; ...@@ -9,26 +9,26 @@ use ndarray::*;
use rand::Rng; use rand::Rng;
use std::num::Wrapping; use std::num::Wrapping;
pub const n: usize = 500; pub const n: usize = 10;
pub const N: usize = 1024; pub const N: usize = 1024;
pub const N2: usize = N/2; pub const N2: usize = N/2+1;
pub const K: usize = 3; pub const K: usize = 3; // K
pub const K2: usize = 6; pub const K2: usize = 6;
pub const Q: usize = 1 << 32; pub const Q: usize = 1 << 32; // Q
pub const q: usize = 512; pub const q: usize = 512;
pub const q2: usize = 256; pub const q2: usize = q/2;
type ZmodQ = Wrapping<i32>; pub type ZmodQ = Wrapping<i32>;
type UZmodQ = Wrapping<u32>; type UZmodQ = Wrapping<u32>;
const v: ZmodQ = Wrapping((1 << 29) + 1); const V: ZmodQ = Wrapping((1 << 29) + 1);
const v_inverse: ZmodQ = Wrapping(-536870911); // 3758096385; const V_INVERSE: ZmodQ = Wrapping(-536870911); // 3758096385; 1/V mod Q
const vgprime: [ZmodQ; 3] = [Wrapping(536870913), Wrapping(2048), Wrapping(4194304)]; // [v, v<<11, v<<22]; const VGPRIME: [ZmodQ; 3] = [Wrapping(V.0), Wrapping(V.0 << 11), Wrapping(V.0 << 22)]; // [V, V<<11, V<<22];
const g_bits: [isize; 3] = [11, 11, 10]; const G_BITS: [isize; 3] = [11, 11, 10];
const g_bits_32: [isize; 3] = [21, 21, 22]; const G_BITS_32: [isize; 3] = [21, 21, 22];
pub const KS_BASE: usize = 25; pub const KS_BASE: usize = 25;
pub const KS_EXP: usize = 7; pub const KS_EXP: usize = 7;
...@@ -46,22 +46,30 @@ pub const BS_BASE: usize = 23; ...@@ -46,22 +46,30 @@ pub const BS_BASE: usize = 23;
pub const BS_EXP: usize = 2; pub const BS_EXP: usize = 2;
pub const BS_TABLE: [usize; 2] = [1, 23]; pub const BS_TABLE: [usize; 2] = [1, 23];
#[derive(Clone)] // #[derive(Clone,Debug)]
pub struct RingModQ(pub Array1<ZmodQ>); // [ZmodQ; N]; // pub struct RingModQ(pub Array1<ZmodQ>); // [ZmodQ; N];
impl Default for RingModQ { // impl Default for RingModQ {
fn default() -> RingModQ { // fn default() -> RingModQ {
RingModQ(Array::default(N)) // RingModQ(Array::default(N))
} // }
// }
pub type RingModQ = Array<ZmodQ,Ix1>;
pub fn RingModQ() -> RingModQ {
Array::default(N)
} }
#[derive(Clone)] // #[derive(Clone,Debug)]
pub struct RingFFT(pub Array1<c64>); // [c64; N2]; // pub struct RingFFT(pub Array1<c64>); // [c64; N2];
impl Default for RingFFT { // impl Default for RingFFT {
fn default() -> RingFFT { // fn default() -> RingFFT {
RingFFT(Array::default(N2)) // RingFFT(Array::default(N2))
} // }
// }
pub type RingFFT = Array<c64,Ix1>;
pub fn RingFFT() -> RingFFT {
Array::default(N2)
} }
#[derive(Copy,Clone)] #[derive(Copy,Clone,Debug)]
pub enum BinGate { OR, AND, NOR, NAND } pub enum BinGate { OR, AND, NOR, NAND }
const GATE_CONST: [usize; 4] = [15*q/8, 9*q/8, 11*q/8, 13*q/8]; const GATE_CONST: [usize; 4] = [15*q/8, 9*q/8, 11*q/8, 13*q/8];
...@@ -75,11 +83,11 @@ struct Distrib { ...@@ -75,11 +83,11 @@ struct Distrib {
table: &'static [f64] table: &'static [f64]
} }
fn sample(chi: &Distrib, rng: &mut rand::rngs::ThreadRng) -> i32 { fn sample(chi: &Distrib) -> i32 {
// dbg!(chi.std_dev); // dbg!(chi.std_dev);
if chi.max != 0 { if chi.max != 0 { // CHI1, CHI_BINARY
// println!("path 1"); // println!("path 1");
let r: f64 = rng.gen(); let r: f64 = rand::thread_rng().gen();
for i in 0..chi.max { for i in 0..chi.max {
if r <= chi.table[i as usize] { if r <= chi.table[i as usize] {
return i - chi.offset; return i - chi.offset;
...@@ -89,24 +97,24 @@ fn sample(chi: &Distrib, rng: &mut rand::rngs::ThreadRng) -> i32 { ...@@ -89,24 +97,24 @@ fn sample(chi: &Distrib, rng: &mut rand::rngs::ThreadRng) -> i32 {
} }
let mut r: f64; let mut r: f64;
let s = chi.std_dev; let s = chi.std_dev;
if s < 500.0 { if s < 500.0 { // CHI3
// println!("path 2"); // println!("path 2");
let mut x: i32; let mut x: i32;
let maxx= (s*8.0).ceil() as i32; let maxx= (s*8.0).ceil() as i32;
loop { loop {
x = rng.gen::<i32>() % (2*maxx + 1) - maxx; x = rand::thread_rng().gen::<i32>() % (2*maxx + 1) - maxx;
r = rng.gen(); r = rand::thread_rng().gen();
// println!("x = {}, y = {}, z = {}", r, x, s); // println!("x = {}, y = {}, z = {}", r, x, s);
if r < (- (x*x) as f64 / (2.0*s*s)).exp() { if r < (- (x*x) as f64 / (2.0*s*s)).exp() {
return x; return x;
} }
} }
} else { } else { // CHI2
// println!("path 3"); // println!("path 3");
let mut x: f64; let mut x: f64;
loop { loop {
x = 16.0 * rng.gen::<f64>() - 8.0; x = 16.0 * rand::thread_rng().gen::<f64>() - 8.0;
r = rng.gen(); r = rand::thread_rng().gen();
// println!("r = {}\tx = {}\ts = {}", r, x, s); // println!("r = {}\tx = {}\ts = {}", r, x, s);
if r < (- x*x / 2.0).exp() { if r < (- x*x / 2.0).exp() {
return (0.5 + x*s).floor() as i32; return (0.5 + x*s).floor() as i32;
...@@ -174,30 +182,41 @@ impl Default for FFT { ...@@ -174,30 +182,41 @@ impl Default for FFT {
in_: AlignedVec::new(N*2), in_: AlignedVec::new(N*2),
out: AlignedVec::new(N+1), out: AlignedVec::new(N+1),
plan_fft_forw: R2CPlan64::aligned(&[N*2], Flag::PATIENT).unwrap(), plan_fft_forw: R2CPlan64::aligned(&[N*2], Flag::PATIENT).unwrap(),
plan_fft_back: C2RPlan64::aligned(&[N*2], Flag::PATIENT).unwrap() plan_fft_back: C2RPlan64::aligned(&[N*2], Flag::PATIENT | Flag::PRESERVEINPUT).unwrap()
} }
} }
} }
pub fn fft_setup(ffto: &mut FFT) { pub fn fft_setup(ffto: &mut FFT) {
*ffto = Default::default(); *ffto = Default::default();
} }
pub fn fft_forward(ffto: &mut FFT, val: &RingModQ, res: &mut RingFFT) { pub fn fft_forward<'a,'b>(ffto: &mut FFT, val: ArrayView<'a,ZmodQ,Ix1>, mut res: ArrayViewMut<'b,c64,Ix1>) {
for k in 0..N { for k in 0..N {
ffto.in_[k] = val.0[k].0 as f64; ffto.in_[k] = val[k].0 as f64;
ffto.in_[k+N] = 0.0; ffto.in_[k+N] = 0.0;
} }
ffto.plan_fft_forw.r2c(&mut ffto.in_, &mut ffto.out).unwrap(); ffto.plan_fft_forw.r2c(&mut ffto.in_, &mut ffto.out).unwrap();
for k in 0..N2 { for k in 0..(N2-1) {
res.0[k] = ffto.out[2*k+1]; res[k] = ffto.out[2*k+1];
// res[k] = ffto.out[k];
} }
} }
pub fn fft_backward(ffto: &mut FFT, val: &RingFFT, res: &mut RingModQ) { pub fn fft_backward<'a,'b>(ffto: &mut FFT, val: ArrayView<'a,c64,Ix1>, mut res: ArrayViewMut<'b,ZmodQ,Ix1>) {
for k in 0..N2 { for k in 0..N2 {
ffto.out[2*k+1] = val.0[k]/c64::new(N as f64,0.0);
ffto.out[2*k] = c64::new(0.0,0.0); ffto.out[2*k] = c64::new(0.0,0.0);
if k < N2 - 1 {
ffto.out[2*k+1] = val[k]/c64::new(N as f64,0.0);
}
// ffto.out[k] = val[k]; // /c64::new(N as f64,0.0);
} }
ffto.plan_fft_back.c2r(&mut ffto.out, &mut ffto.in_).unwrap(); ffto.plan_fft_back.c2r(&mut ffto.out, &mut ffto.in_).unwrap();
for k in 0..N { for k in 0..N {
res.0[k] = Wrapping(ffto.in_[k].round() as i32); // let max: i64 = i32::MAX as i64;
// let min: i64 = i32::MIN as i64;
// let div: i64 = max - min + 1;
// let mut t = ffto.in_[k].round() as i64 % div;
// t = if t > max { t - div } else if t < min { t + div } else { t };
// res[k] = Wrapping(t as i32);
res[k] = Wrapping(ffto.in_[k].round().rem_euclid(2f64.powi(32)) as u32 as i32);
} }
} }
use rand::Rng; use rand::Rng;
use crate::*; use crate::*;
use ndarray::parallel::prelude::*;
use ndarray::linalg::*;
use rayon::prelude::*;
#[derive(Clone)] #[derive(Clone,Debug)]
pub struct CipherText { pub struct CipherText {
pub a: Array1<i32>, // [isize; n], pub a: Array1<i32>, // [isize; n],
pub b: i32 pub b: i32
...@@ -14,7 +17,7 @@ impl Default for CipherText { ...@@ -14,7 +17,7 @@ impl Default for CipherText {
} }
} }
} }
#[derive(Clone)] #[derive(Clone,Debug)]
pub struct CipherTextQ { pub struct CipherTextQ {
pub a: Array1<ZmodQ>, // [ZmodQ; n], pub a: Array1<ZmodQ>, // [ZmodQ; n],
pub b: ZmodQ pub b: ZmodQ
...@@ -27,7 +30,7 @@ impl Default for CipherTextQ { ...@@ -27,7 +30,7 @@ impl Default for CipherTextQ {
} }
} }
} }
#[derive(Clone)] #[derive(Clone,Debug)]
pub struct CipherTextQN { pub struct CipherTextQN {
pub a: Array1<ZmodQ>, // [ZmodQ; n], pub a: Array1<ZmodQ>, // [ZmodQ; n],
pub b: ZmodQ pub b: ZmodQ
...@@ -35,123 +38,145 @@ pub struct CipherTextQN { ...@@ -35,123 +38,145 @@ pub struct CipherTextQN {
impl Default for CipherTextQN { impl Default for CipherTextQN {
fn default() -> Self { fn default() -> Self {
CipherTextQN { CipherTextQN {
a: Array::default(n), a: Array::default(N),
b: Default::default() b: Default::default()
} }
} }
} }
#[derive(Clone)] // #[derive(Clone,Debug)]
pub struct SecretKey(pub Array1<i32>); // [isize; n]; // pub struct SecretKey(pub Array1<i32>); // [isize; n];
impl Default for SecretKey { // impl Default for SecretKey {
fn default() -> Self { // fn default() -> Self {
SecretKey(Array::default(n)) // SecretKey(Array::default(n))
} // }
// }
pub type SecretKey = Array<i32,Ix1>;
pub fn SecretKey() -> SecretKey {
Array::default(n)
}
// #[derive(Clone,Debug)]
// pub struct SecretKeyN(pub Array1<i32>); // [isize; N];
// impl Default for SecretKeyN {
// fn default() -> Self {
// SecretKeyN(Array::default(N))
// }
// }
pub type SecretKeyN = Array<i32,Ix1>;
pub fn SecretKeyN() -> SecretKeyN {
Array::default(N)
} }
#[derive(Clone)]
pub struct SecretKeyN(pub Array1<i32>); // [isize; N]; // #[derive(Clone,Debug)]
impl Default for SecretKeyN { // pub struct SwitchingKey(pub Array3<CipherTextQ>); // [[[CipherTextQ; KS_EXP]; KS_BASE]; N];
fn default() -> Self { // impl Default for SwitchingKey {
SecretKeyN(Array::default(N)) // fn default() -> Self {
} // SwitchingKey(Array::default((N,KS_BASE,KS_EXP)))
// }
// }
pub type SwitchingKey = Array<CipherTextQ,Ix3>;
pub fn SwitchingKey() -> SwitchingKey {
let k = Array::default((N,KS_BASE,KS_EXP));
eprintln!("evaluation key 2 -> switching key zeroing complete");
k
} }
const qi: i32 = q as i32; const qi: i32 = q as i32;
pub fn keyGen(sk: &mut SecretKey, rng: &mut rand::rngs::ThreadRng) { pub fn key_gen(sk: &mut SecretKey) {
loop { // loop {
let mut s = 0; let mut s = 0;
let mut ss = 0; let mut ss = 0;
for i in 0..n { for i in 0..n {
sk.0[i] = sample(&CHI_BINARY, rng); // sk[i] = sample(&CHI_BINARY);
s += sk.0[i]; sk[i] = BINARY_TABLE[i%3].floor() as i32;
ss += (sk.0[i]).abs(); s += sk[i];
} ss += (sk[i]).abs();
if s.abs() > 5 || (ss - (n as i32) / 2).abs() > 5 {
continue;
} else {
break;
}
} }
// if s.abs() > 5 || (ss - (n as i32) / 2).abs() > 5 {
// continue;
// } else {
// break;
// }
// }
} }
pub fn keyGenN(sk: &mut SecretKeyN, rng: &mut rand::rngs::ThreadRng) { pub fn key_gen_N(sk: &mut SecretKeyN) {
for i in 0..N { for i in 0..N {
sk.0[i] = sample(&CHI1, rng); // sk[i] = sample(&CHI1);
sk[i] = (i as i32) % 2;
} }
} }
pub fn encrypt(ct: &mut CipherText, sk: &SecretKey, m: i32, rng: &mut rand::rngs::ThreadRng) { pub fn encrypt(ct: &mut CipherText, sk: &SecretKey, m: i32) {
ct.b = (m % 4) * qi / 4 + sample(&CHI3, rng); ct.b = (m % 4) * qi / 4; // + sample(&CHI3);
for i in 0..n { for i in 0..n {
ct.a[i] = rng.gen::<i32>() % qi; // ct.a[i] = rand::thread_rng().gen::<i32>() % qi;
ct.b = (ct.b + ct.a[i] * sk.0[i]) % qi; ct.a[i] = (i as i32) % qi;
ct.b = (ct.b + ct.a[i] * sk[i]) % qi;
// println!("i = {}, ct.b = {}", i, ct.b);
} }
// println!("m = {}, ct.b = {}, ct->a = {:?}", m, ct.b, ct.a);
} }
pub fn decrypt(sk: &SecretKey, ct: &CipherText) -> i32 { pub fn decrypt(sk: &SecretKey, ct: &CipherText) -> i32 {
let mut r = ct.b; let mut r = ct.b;
// dbg!(r);
for i in 0..n { for i in 0..n {
r -= ct.a[i] * sk.0[i]; r -= ct.a[i] * sk[i];
// dbg!(r);
} }
r = ((r % qi) + qi + qi/8) % qi; r = ((r % qi) + qi + qi/8) % qi;
// dbg!(r,qi);
// dbg!(r, qi, 4*r/qi);
4 * r / qi 4 * r / qi
} }
#[derive(Clone)] pub fn switching_key_gen<'b,'c>(res: &mut SwitchingKey, new_sk: ArrayView<'b,i32,Ix1>, old_sk: ArrayView<'c,i32,Ix1>) {
pub struct SwitchingKey(pub Array3<CipherTextQ>); // [[[CipherTextQ; KS_EXP]; KS_BASE]; N]; // dbg!(&res[[0,0,0]]);
impl Default for SwitchingKey {
fn default() -> Self {
SwitchingKey(Array::default((N,KS_BASE,KS_EXP)))
}
}
pub fn switchingKeyGen(res: &mut SwitchingKey, new_sk: &SecretKey, old_sk: &SecretKeyN, rng: &mut rand::rngs::ThreadRng) {
for i in 0..N { for i in 0..N {
for j in 0..KS_BASE { for j in 0..KS_BASE {
for k in 0..KS_EXP { for k in 0..KS_EXP {
// dbg!("switching key",i,j,k);
let mut ct: CipherTextQ = Default::default(); let mut ct: CipherTextQ = Default::default();
ct.b = - Wrapping(old_sk.0[i]) * Wrapping(j as i32) * Wrapping(KS_TABLE[k]) + Wrapping(sample(&CHI2, rng)); // ct.a.par_map_inplace(|x| *x = rand::thread_rng().gen());
for l in 0..n { ct.a.par_map_inplace(|x| *x = Wrapping(i as i32) * Wrapping(j as i32) * Wrapping(k as i32));
ct.a[l] = rng.gen(); ct.b = // Wrapping(sample(&CHI2))
let addend = ct.a[l] * Wrapping(new_sk.0[l]); - Wrapping(old_sk[i]) * Wrapping(j as i32) * Wrapping(KS_TABLE[k])
// dbg!(ct.b, addend, isize::MAX, isize::MIN); + ct.a.dot::<Array<Wrapping<i32>,Ix1>>(&new_sk.mapv(Wrapping));
// let upper = addend <= isize::MAX - ct.b; // + Zip::from(&ct.a).and(new_sk).fold(Wrapping(0), |acc, &x, &y| acc + x * Wrapping(y));
// let lower = addend >= isize::MIN - ct.b; // + ct.a.par_iter().zip(new_sk.par_iter()).map(|(x,y)| x*Wrapping(y)).sum();
// ct.b = if upper && lower { res[[i,j,k]] = ct;
// ct.b + addend
// } else if !upper && lower {
// isize::MIN + (addend - (isize::MAX - ct.b)) - 1
// } else {
// isize::MAX - (- addend - (ct.b - isize::MIN)) + 1
// };
ct.b = ct.b + addend;
// dbg!(ct.b);
}
res.0[[i,j,k]] = ct;
} }
} }
} }
// dbg!(&res[[0,0,0]]);
eprintln!("evaluation key 3 -> switching key generation complete");
} }
pub fn keySwitch(res: &mut CipherTextQ, ksk: &SwitchingKey, ct: &CipherTextQN) { pub fn key_switch(res: &mut CipherTextQ, ksk: &SwitchingKey, ct: &CipherTextQN) {
for k in 0..n { for k in 0..n {
res.a[k] = Wrapping(0); res.a[k] = Wrapping(0);
// println!("res.a[{}] = {}", k, res.a[k]);
} }
res.b = ct.b; res.b = ct.b;
// dbg!(res.b);
for i in 0..N { for i in 0..N {
let a: UZmodQ = Wrapping(- ct.a[i].0 as u32); let mut a: UZmodQ = Wrapping(0) - Wrapping(ct.a[i].0 as u32);
for j in 0..KS_BASE { // println!("i = {}, a = {}", i, a);
for j in 0..KS_EXP {
let a0: UZmodQ = a % Wrapping(KS_BASE as u32); let a0: UZmodQ = a % Wrapping(KS_BASE as u32);
// print!("i = {}, j = {}, ksk[i][{}][j].a = [", i, j, a0.0);
for k in 0..n { for k in 0..n {
res.a[k] -= ksk.0[[i,a0.0 as usize,j]].a[k]; res.a[k] -= ksk[[i,a0.0 as usize,j]].a[k];
res.b -= ksk.0[[i,a0.0 as usize,j]].b; // print!("{},", ksk[[i,a0.0 as usize,j]].a[k]);
} }
// println!("]");
res.b -= ksk[[i,a0.0 as usize,j]].b;
a /= Wrapping(KS_BASE as u32);
} }
} }
} }
pub fn modSwitch(ct: &mut CipherText, c: &CipherTextQ, rng: &mut rand::rngs::ThreadRng) { pub fn mod_switch(ct: &mut CipherText, c: &CipherTextQ) {
for i in 0..n { ct.a = c.a.mapv(round_q_Q);
ct.a[i] = round_qQ(c.a[i]); ct.b = round_q_Q(c.b);
}
ct.b = round_qQ(c.b);
} }
pub fn round_qQ(v_: ZmodQ) -> i32 { pub fn round_q_Q(v: ZmodQ) -> i32 {
(0.5 + (v_.0 as f64) * (q as f64) / (Q as f64)).floor() as i32 (0.5 + (v.0 as f64) * (q as f64) / (Q as f64)).floor() as i32
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment