from __future__ import annotations import hashlib import itertools import math import random from secrets import SystemRandom from typing import List, Union from .week1 import BloodType, blood_cell_compatibility_lookup from .week4 import gen_prime class ElGamal: def __init__(self, g, q, p): self.gen_ = g self.order = q self.p = p def gen_key(self): key = SystemRandom().randint(1, self.order) while math.gcd(self.order, key) != 1: key = SystemRandom().randint(1, self.order) return key def gen(self, sk): h = pow(self.gen_, sk, self.order) return (self.gen_, h) def enc(self, m, pk): # sample random r \in Zq r = SystemRandom().randint(1, self.order) g, h = pk s = pow(h, r, self.order) p = pow(g, r, self.order) tmp = int.from_bytes(m, "big") c = s * tmp return c, p def dec(self, c, sk): c1, c2 = c h = pow(c2, sk, self.order) m = c1 // h return m.to_bytes(16, "big") def ogen(self): s = SystemRandom().randint(1, self.order) h = pow(s, 2, self.order) return self.gen_, h def sha256(b: bytes) -> bytes: return hashlib.sha256(b).digest() def rand_bytes(): return SystemRandom().getrandbits(128).to_bytes(16, "big") def xor_bytes(a: bytes, b: bytes, k=32) -> bytes: return (int.from_bytes(a, "big") ^ int.from_bytes(b, "big")).to_bytes(k, "big") class Gate: def __init__(self, left: Union[Gate, InputWire], right: Union[Gate, InputWire], index: int) -> None: self.left = left self.right = right self.i = index self.k = { 0: rand_bytes(), 1: rand_bytes() } self.output = None c_prime = {} for a, b in itertools.product((0, 1), repeat=2): c_prime[(a, b)] = xor_bytes( sha256(self.left.k[a] + self.right.k[b] + self.i.to_bytes(1, "big")), self.k[self.f(a, b)] + bytes(16) ) pi = list(itertools.product((0, 1), repeat=2)) random.shuffle(pi) self.c = {i: c_prime[p] for i, p in enumerate(pi)} def f(self, a, b): raise NotImplemented class ImplyGate(Gate): def f(self, a, b): return a >= b class AndGate(Gate): def f(self, a, b): return a * b class InputWire: def __init__(self, index) -> None: self.i = index self.k = { 0: rand_bytes(), 1: rand_bytes() } self.output = None class Circuit: def __init__(self, input_wires: List[InputWire], gates: List[Gate]) -> None: self.input_wires = input_wires self.gates = gates @property def d(self): return self.gates[-1].k def evaluate(self, x: List[bytes]) -> bytes: for i, input_wire in enumerate(self.input_wires): input_wire.output = x[i] for gate in self.gates: for j in range(4): xor = xor_bytes( sha256(gate.left.output + gate.right.output + gate.i.to_bytes(1, "big")), gate.c[j] ) k_prime, tau = xor[:16], xor[16:] if tau == bytes(16): gate.output = k_prime return self.gates[-1].output def encode(e: List[InputWire], x: List[int]) -> List[bytes]: return [e[i].k[xi] for i, xi in enumerate(x)] class Alice: def __init__(self, ra, rb, rs, elgamal): self.elgamal = elgamal self.sks = None self.input = [ra, rb, rs] self.keys = None def send_pks(self): pks = [] self.sks = [] for idx, input_ in enumerate(self.input): sk = self.elgamal.gen_key() pk = self.elgamal.gen(sk) self.sks.append(sk) fake_pk = self.elgamal.ogen() pk_tuple = [fake_pk] pk_tuple.insert(input_, pk) pks.append(pk_tuple) return pks def retrieve(self, circuit, bob_keys, ciphers): self.keys = [] for idx, sk in enumerate(self.sks): self.keys.append(self.elgamal.dec(ciphers[idx][self.input[idx]], sk)) all_keys = self.keys + bob_keys res = circuit.evaluate(all_keys) if circuit.d[0] == res: return 0 if circuit.d[1] == res: return 1 raise Exception("Fuck you") class Bob: def __init__(self, da, db, ds, elgamal): input_wire1 = InputWire(0) input_wire2 = InputWire(1) input_wire3 = InputWire(2) input_wire4 = InputWire(3) input_wire5 = InputWire(4) input_wire6 = InputWire(5) impl_gate_1 = ImplyGate(input_wire1, input_wire4, 6) impl_gate_2 = ImplyGate(input_wire2, input_wire5, 7) impl_gate_3 = ImplyGate(input_wire3, input_wire6, 8) and_gate_1 = AndGate(impl_gate_1, impl_gate_2, 9) and_gate_2 = AndGate(and_gate_1, impl_gate_3, 10) self.circuit = Circuit( input_wires=[input_wire1, input_wire2, input_wire3, input_wire4, input_wire5, input_wire6], gates=[impl_gate_1, impl_gate_2, impl_gate_3, and_gate_1, and_gate_2] ) self.own_keys = encode([input_wire4, input_wire5, input_wire6], [da, db, ds]) self.key_set = [x.k.values() for x in [input_wire1, input_wire2, input_wire3]] self.elgamal = elgamal self.pks = None def receive_pks(self, pks): self.pks = pks def transfer_messages(self): ciphers = [] for idx, (k0, k1) in enumerate(self.key_set): pk0, pk1 = self.pks[idx] c0 = self.elgamal.enc(k0, pk0) c1 = self.elgamal.enc(k1, pk1) ciphers.append((c0, c1)) return self.circuit, self.own_keys, ciphers def run(da, db, ds, ra, rb, rs): p = gen_prime(256) q = 2 * p + 1 g = SystemRandom().randint(2, q) elgamal = ElGamal(g, q, p) alice = Alice(ra=ra, rb=rb, rs=rs, elgamal=elgamal) bob = Bob(da=da, db=db, ds=ds, elgamal=elgamal) bob.receive_pks(alice.send_pks()) pls = alice.retrieve(*bob.transfer_messages()) return pls def main(): green = 0 red = 0 for i, recipient in enumerate(BloodType): for j, donor in enumerate(BloodType): z = run(*donor.value, *recipient.value) lookup = blood_cell_compatibility_lookup(recipient, donor) if lookup == z: green += 1 else: print(f"'{BloodType(donor).name} -> {BloodType(recipient).name}' should be {lookup}.") red += 1 print("Green:", green) print("Red :", red) # run(donor=BloodType.A_NEGATIVE, recipient=BloodType.B_POSITIVE)