#!/usr/bin/env sage
# KAZ-SIGN forgery test (SageMath script; `^` below is exponentiation).
#
# Steps: build the reference KATs and libkaz.so in the current directory,
# read the scheme parameters from kaz_api.h, then for each KAT in *.rsp
# check it with both the reference crypto_sign_open and the Sage
# re-implementation below, extract the value `ex` from the valid signature,
# and use it to forge signatures on new messages that both verifiers accept.

import os
import glob
import subprocess
import ctypes
from ctypes import c_int,c_char_p,c_ulonglong,POINTER,byref,create_string_buffer
import hashlib
import random


def hash(seed):
    """Return the SHA-256 digest of the bytes `seed`.

    Note: this shadows the builtin hash() for the rest of the script.
    """
    h = hashlib.sha256()
    h.update(seed)
    return h.digest()


# Let Sage use faster non-proven algorithms (e.g. pseudoprimality tests)
# in the euler_phi / factor / multiplicative_order calls below.
proof.all(False)

# ----- build test vectors, and compile library for future use

subprocess.run(['make'])
subprocess.run(['./PQCgenKAT_sign'])
subprocess.run('gcc -shared -o libkaz.so kaz_api.c sign.c rng.c sha256.c -fPIC -lcrypto -lgmp',shell=True)

# ----- extract system parameters from kaz_api.h


def _run_c_int(name,source,libs=()):
    """Compile C `source` into ./macro_<name>, run it, and return its output as an int.

    The program's stdout is stripped and has all spaces removed before being
    parsed with int().  Raises subprocess.CalledProcessError if compilation
    or execution fails, and ValueError if the output is not an integer.
    """
    with open('macro_%s.c'%name,'w') as f:
        f.write(source)
    subprocess.run(['gcc','-o','macro_%s'%name,'macro_%s.c'%name,*libs],check=True)
    out = subprocess.run(['./macro_%s'%name],check=True,text=True,capture_output=True).stdout
    return int(out.strip().replace(' ',''))


def api_num(macro):
    """Return the integer value of the kaz_api.h macro KAZ_DS_<macro>.

    SP_* macros other than SP_BETA are printed with puts() (string-valued
    macros, which must hold a decimal number for int() to parse); all other
    macros are printed with printf("%d").
    """
    if 'SP' in macro and macro != 'SP_BETA':
        program = 'int main() { puts(KAZ_DS_%s); return 0; }\n' % macro
    else:
        program = 'int main() { printf("%%d\\n",KAZ_DS_%s); return 0; }\n' % macro
    return _run_c_int(macro,'#include <stdio.h>\n#include "kaz_api.h"\n'+program)


N = api_num('SP_N')
Q = api_num('SP_Q')
q = api_num('SP_q')
G = api_num('SP_G')
R = api_num('SP_R')
BETA = api_num('SP_BETA')
Gg = api_num('SP_Gg')
GRg = api_num('SP_GRg')
phiq2o2 = api_num('SP_PHIQ2')
ALPHABYTES = api_num('ALPHABYTES')
V1BYTES = api_num('V1BYTES')
V2BYTES = api_num('V2BYTES')
S1BYTES = api_num('S1BYTES')
S2BYTES = api_num('S2BYTES')

# PHIQ2BETA2 = KAZ_DS_SP_PHIQ2 / BETA^2, computed in C with GMP.
# mpz_divexact_ui assumes BETA^2 divides KAZ_DS_SP_PHIQ2 exactly.
PHIQ2BETA2 = _run_c_int('phiq2beta2',
    '#include <stdio.h>\n'
    '#include <gmp.h>\n'
    '#include "kaz_api.h"\n'
    'int main() {\n'
    ' mpz_t PHIQ2,PHIQ2BETA2;\n'
    ' int BETA=KAZ_DS_SP_BETA;\n'
    # mpz_inits takes a NULL-terminated list of variables.
    ' mpz_inits(PHIQ2,PHIQ2BETA2,NULL);\n'
    ' mpz_set_str(PHIQ2, KAZ_DS_SP_PHIQ2, 10);\n'
    ' mpz_divexact_ui(PHIQ2BETA2, PHIQ2, BETA*BETA);\n'
    ' gmp_printf("%Zd\\n",PHIQ2BETA2);\n'
    ' return 0;\n'
    '}\n',
    libs=('-lgmp',))

# ----- derived parameters

PHIN = euler_phi(N)
ORDERG = Mod(G,N).multiplicative_order()
PHIORDERG = euler_phi(ORDERG)
ORDERR = Mod(R,PHIN).multiplicative_order()
phiphin = euler_phi(PHIN)
realRorder = Mod(R,ORDERG).multiplicative_order()

# The order of R modulo ord(G) must match KAZ_DS_SP_GRg; the remaining
# three asserts re-check the values computed just above.
assert realRorder == GRg
assert PHIN == euler_phi(N)
assert ORDERG == Mod(G,N).multiplicative_order()
assert ORDERR == Mod(R,PHIN).multiplicative_order()

beta = max(p for p,e in factor(GRg))  # largest prime factor of GRg

# ----- verification and forgery


def decode(b):
    """Return the bytes `b` interpreted as a big-endian unsigned integer."""
    b = bytearray(b)
    while b[:1] == b'\0':
        b = b[1:]
    return sum(c<<(8*i) for i,c in enumerate(reversed(b)))


def encode(i,targetbytes):
    """Return `i` as exactly `targetbytes` big-endian bytes, left-padded with zeros.

    Asserts if `i` needs more than `targetbytes` bytes.
    """
    result = bytearray()
    while i > 0:
        result = bytearray([i%256])+result
        i >>= 8
    while len(result) < targetbytes:
        result = bytearray([0])+result
    assert len(result) == targetbytes
    return bytes(result)


assert decode(encode(31415,5)) == 31415


def open_extract(sm,pk):
    """Verify the signed message `sm` under public key `pk` in Sage.

    pk is v1 || v2 (V1BYTES + V2BYTES bytes); sm is s1 || s2 || m
    (S1BYTES + S2BYTES bytes followed by the message).  Every check is an
    assert, so a rejected signature raises AssertionError (or an arithmetic
    error from the modular division below).

    Returns (m, ex), where ex is the value w8 computed below before its final
    exponentiation by q0; forge() reuses ex as the residue of alpha mod q^2.
    """
    v1,pk = pk[:V1BYTES],pk[V1BYTES:]
    v2,pk = pk[:V2BYTES],pk[V2BYTES:]
    assert len(v1) == V1BYTES
    assert len(v2) == V2BYTES
    assert len(pk) == 0

    s1,sm = sm[:S1BYTES],sm[S1BYTES:]
    s2,m = sm[:S2BYTES],sm[S2BYTES:]
    assert len(s1) == S1BYTES
    assert len(s2) == S2BYTES

    h = hash(m)

    v1,v2,s1,s2,h = map(decode,(v1,v2,s1,s2,h))

    gs1r = gcd(s1,GRg)
    aF = v1%GRg
    q2 = q*q
    GRgq2 = GRg*q2

    try:
        s1inverse = Mod(s1,GRgq2)^(-1)
    except ArithmeticError:  # Sage raises ZeroDivisionError when no inverse exists
        print('warning: non-invertible s1')
        s1inverse = 0

    # s2 must differ from these two directly computed expressions.
    assert gs1r*(h+Mod(v1,GRgq2)^s1)*s1inverse != s2
    assert gs1r*(h+Mod(aF,GRgq2)^s1)*s1inverse != s2

    assert Mod(s1*s2-gs1r*h,q) == gs1r*Mod(v1,q)^s1

    # PHIQ2BETA2 = ZZ(phiq2o2 / (BETA*BETA))
    try:
        w7 = Mod(s1,PHIQ2BETA2)^(-1)
    except ArithmeticError:
        w7 = 0
    w7 = ZZ(w7)
    q0 = ZZ((q-1)/Q)
    w8first = Mod(s1*s2-gs1r*h,q2)
    w8 = (w8first/gs1r)^w7
    ex = ZZ(w8)
    w8 ^= q0
    w8 = ZZ(w8)
    # Minimal big-endian byte length of w8 (0 bytes when w8 == 0).
    w8alloc = 0
    while 256^w8alloc <= w8:
        w8alloc += 1
    w8str = encode(w8,w8alloc)
    assert decode(hash(w8str)) == v2

    assert Mod(G,N)^(Mod(R,Gg)^(s1*s2)) == Mod(G,N)^(Mod(R,ORDERG)^(gs1r*(h+Mod(v1,GRg)^s1)))

    return m,ex


def sage_open(sm,pk):
    """Return the message of `sm` if open_extract() accepts it under `pk`."""
    return open_extract(sm,pk)[0]


def extract(sm,pk):
    """Return the value `ex` that open_extract() derives from a valid `sm`."""
    return open_extract(sm,pk)[1]


# Reference verifier: crypto_sign_open from the freshly built libkaz.so.
libkaz = ctypes.CDLL(f'{os.getcwd()}/libkaz.so')
libkaz_open = libkaz.crypto_sign_open
libkaz_open.argtypes = c_char_p,POINTER(c_ulonglong),c_char_p,c_ulonglong,c_char_p
libkaz_open.restype = c_int


def reference_open(sm,pk):
    """Open `sm` under `pk` with the reference C library and return the message.

    Asserts that crypto_sign_open returns 0.
    """
    smlen = c_ulonglong(len(sm))
    m = create_string_buffer(len(sm))  # sized like sm; assumed to bound the message length
    mlen = c_ulonglong(0)
    pk = create_string_buffer(pk)
    assert libkaz_open(m,byref(mlen),sm,smlen,pk) == 0
    return m.raw[:mlen.value]


def forge(m,pk,ex):
    """Return a candidate signed message s1 || s2 || m for `m` under `pk`.

    `ex` is the value returned by extract() for a valid signature under the
    same `pk`.  s1 is drawn at random (a multiple of BETA satisfying the gcd
    conditions below).  The candidate is not guaranteed to verify, which is
    why forgerytest() retries.
    """
    h = hash(m)
    v1,pk = pk[:V1BYTES],pk[V1BYTES:]
    v2,pk = pk[:V2BYTES],pk[V2BYTES:]
    v1,v2,h = map(decode,(v1,v2,h))

    q2 = q*q
    GRgq2 = GRg*q2
    # alpha is congruent to v1 modulo GRg and to ex modulo q^2.
    alpha = CRT(v1,ex,GRg,q2)

    while True:
        r0 = random.randrange(2**256)
        r = r0*BETA
        s1 = r%GRgq2
        gs1 = gcd(r,GRg)
        if gcd(ZZ(r/gs1),GRgq2) != 1:
            continue
        if gcd(ZZ(s1/gs1),GRgq2) != 1:
            continue
        if s1%BETA != 0:
            continue
        if gcd(s1,PHIQ2BETA2) != 1:
            continue
        break

    s2 = (h+Mod(alpha,GRgq2)^s1)/ZZ(s1/gs1)
    s1,s2 = map(ZZ,(s1,s2))
    return encode(s1,S1BYTES)+encode(s2,S2BYTES)+m


def forgerytest(pk,ex):
    """Forge signatures on two new messages under `pk` and check that both verifiers accept them.

    Retries forge() until a candidate passes sage_open() and reference_open(),
    then prints it and re-checks it.
    """
    for newmsg in b'forged message',b'another forged message':
        while True:
            sm = forge(newmsg,pk,ex)
            try:
                assert newmsg == sage_open(sm,pk)
                assert newmsg == reference_open(sm,pk)
                break
            except Exception:
                # Candidate rejected (failed assert or arithmetic error): retry.
                # Not a bare except, so Ctrl-C still stops this loop.
                pass
        print(f' newmsg: {newmsg}')
        print(f' forged sm: {sm.hex()}')
        assert newmsg == sage_open(sm,pk)
        assert newmsg == reference_open(sm,pk)
        print(' forgery passes verification')


# ----- try public keys and messages from *.rsp


def _rsp_records(path):
    """Yield one dict per KAT record in the .rsp file at `path`.

    A record starts at a 'count = ...' line.  It ends at the next blank
    line, the next 'count' line, or end of file.  As a result, a final
    record without a trailing blank line is still yielded, and consecutive
    blank lines do not yield a record twice.  Each dict has 'count' plus
    whichever of 'msg', 'pk', 'sm' (hex strings) appeared in the record.
    """
    record = None
    with open(path) as f:
        for line in f:
            fields = line.split()
            if not fields:
                if record is not None:
                    yield record
                    record = None
                continue
            if fields[1:2] != ['=']:
                continue
            key = fields[0]
            value = fields[2] if len(fields) > 2 else ''
            if key == 'count':
                if record is not None:
                    yield record
                record = {'count':value}
            elif record is not None and key in ('msg','pk','sm'):
                record[key] = value
    if record is not None:
        yield record


totalskipped = 0
totalforged = 0

for rsp in glob.glob('*.rsp'):
    print(f'scanning {rsp}')
    for record in _rsp_records(rsp):
        count = record['count']
        print(f'KAT {count}:')
        try:
            # A missing field gives bytes.fromhex(None) -> TypeError -> skip.
            msg = bytes.fromhex(record.get('msg'))
            pk = bytes.fromhex(record.get('pk'))
            sm = bytes.fromhex(record.get('sm'))
            assert len(pk) == V1BYTES+V2BYTES
            assert len(sm) >= S1BYTES+S2BYTES
            assert reference_open(sm,pk) == msg
            assert sage_open(sm,pk) == msg
            ex = extract(sm,pk)
        except Exception:
            # Malformed or unverifiable KAT.  Not a bare except, so Ctrl-C
            # still aborts the run instead of being counted as a skip.
            print(' unable to verify KAT, skipping forgery')
            totalskipped += 1
            continue
        forgerytest(pk,ex)
        totalforged += 1

print('KATs with forgery attempts skipped:',totalskipped)
print('KATs with successful forgery attempts:',totalforged)