https://infossm.github.io/blog/2023/09/16/Brakedown/

 

Brakedown Overview

This is a summary of https://eprint.iacr.org/2021/1043. The goal of the paper is to build a linear-time PCS based on linear codes and apply it to Spartan to obtain a linear-time, field-agnostic SNARK. Spartan…

infossm.github.io

 


https://infossm.github.io/blog/2023/07/14/Monolith/

 

Hash Functions Monolith for ZK Applications: May the Speed of SHA-3 be With You

This is a summary of https://eprint.iacr.org/2023/1025.pdf. The need for ZK-friendly hash functions: hash functions are used in a great many places. Accordingly, within ZKPs as well, proofs about hash function computations…

infossm.github.io

 

#!/usr/bin/sage
import random
import hashlib
import os
import signal 
 
signal.alarm(1800)
 
def PoW():
    prefix = os.urandom(8)
    print(prefix.hex())
    answer = bytes.fromhex(input().strip())
    assert len(answer) == 24
    result = hashlib.sha256(prefix + answer).digest()
    assert result[:3] == b"\x00\x00\x00"
 
P = PolynomialRing(ZZ, 'x')
x = P.gen()
 
def convolution(n, f, g):
    return (f * g) % (x ** n - 1)
 
def balance_mod(f, q):
    tt = f.coefficients(sparse = False)
    ret = 0
    for i in range(len(tt)):
        cc = int((tt[i] + q // 2) % q) - q // 2
        ret += cc * (x ** i)
    return ret
 
def random_poly(n, v1, v2):
    ret = v1 * [1] + v2 * [-1] + (n - v1 - v2) * [0]
    random.shuffle(ret)
    return P(ret)
 
def invert_prime(n, f, p):
    T = P.change_ring(GF(p)).quotient(x ** n - 1)
    ret = P(lift(1 / T(f)))
    return balance_mod(ret, 3)
 
def pad(n, arr):
    while len(arr) < n:
        arr.append(0)
    return arr
 
def encode(n, arr):
    res = 0
    for i in range(n):
        assert -1 <= arr[i] <= 1
        res += (arr[i] + 1) * (3 ** i)
    return res 
 
def task1(n, D):
    random.seed(int.from_bytes(os.urandom(32), "big"))
    f = random_poly(n, n // 3 + 1, n // 3)
    f3 = invert_prime(n, f, 3)
 
    random.seed(int.from_bytes(os.urandom(32), "big"))
    sel1 = random.sample(range(n), D)
    random.seed(int.from_bytes(os.urandom(32), "big"))
    sel2 = random.sample(range(n), D)
 
    coef_original = pad(n, f.coefficients(sparse = False))
    coef_inverse = pad(n, f3.coefficients(sparse = False))
 
    for i in range(D):
        coef_original[sel1[i]] = 0
        coef_inverse[sel2[i]] = 0
    
    print(sel1)
    print(sel2)
    print(encode(n, coef_original))
    print(encode(n, coef_inverse))
 
    assert int(input()) == encode(n, pad(n, f.coefficients(sparse = False)))
    assert int(input()) == encode(n, pad(n, f3.coefficients(sparse = False)))
 
def task2(n, D):
    random.seed(int.from_bytes(os.urandom(32), "big"))
    f = random_poly(n, n // 3 + 1, n // 3)
    f3 = invert_prime(n, f, 3)
    
    seed = int(input())
    random.seed(seed)
 
    sel1 = random.sample(range(n), D)
    sel2 = random.sample(range(n), D)
 
    coef_original = pad(n, f.coefficients(sparse = False))
    coef_inverse = pad(n, f3.coefficients(sparse = False))
 
    for i in range(D):
        coef_original[sel1[i]] = 0
        coef_inverse[sel2[i]] = 0
    
    print(sel1)
    print(sel2)
    print(encode(n, coef_original))
    print(encode(n, coef_inverse))
 
    assert int(input()) == encode(n, pad(n, f.coefficients(sparse = False)))
    assert int(input()) == encode(n, pad(n, f3.coefficients(sparse = False)))
 
PoW()
for _ in range(8):
    task1(2411, 83)
for _ in range(8):
    task2(8501, 2125)
 
flag = open("flag.txt", "r").read()
print(flag)

 

We are given two polynomials $f, f_v$ such that $f \cdot f_v \equiv 1 \pmod{x^n - 1}$ over $\mathbb{F}_3$, but $D$ coefficients of each polynomial are erased (set to zero). We have to recover $f$ and $f_v$ completely, in a reasonably fast and reliable way. The server also gives us the erasure positions.

 

For the first task, $(n, D) = (2411, 83)$ and the erasure positions are completely random. 

For the second task, $(n, D) = (8501, 2125)$ and the erasure positions can be controlled by a user provided seed. 

 

Task 1

By introducing a variable for each erased coefficient, we get a system of $n$ quadratic equations in $2D$ variables over $\mathbb{F}_3$. The interesting part is that some of these equations are actually linear. Let $S_1$ and $S_2$ be the sets of degrees of the erased coefficients of $f$ and $f_v$ respectively. Then the equation obtained from the coefficient of $x^k$ in $f \cdot f_v \pmod{x^n - 1}$ is linear whenever there are no $u \in S_1$ and $v \in S_2$ with $u + v \equiv k \pmod{n}$.
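As a minimal sketch of this observation (plain Python, with a hypothetical helper name `linear_positions`), the indices $k$ that give linear equations are exactly those not hit by any sum $u + v \bmod n$:

```python
def linear_positions(n, S1, S2):
    """Indices k whose equation (coefficient of x^k) has no unknown * unknown term."""
    # A product of two unknowns lands on x^k exactly when k = u + v (mod n)
    # for some erased degree u of f and some erased degree v of f_v.
    quadratic = {(u + v) % n for u in S1 for v in S2}
    return [k for k in range(n) if k not in quadratic]

# Toy usage: n = 7, erased degrees {1, 2} in f and {0, 3} in f_v.
print(linear_positions(7, [1, 2], [0, 3]))
```

In the toy call, the sums are 1, 4, 2, 5, so the equations for $x^0$, $x^3$ and $x^6$ stay linear.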

 

By collecting these equations and solving the linear system, we will be closer to finding the solutions for the $2D$ variables.

However, after implementing this, you will see that the linear system has a nontrivial kernel, of dimension around 40 to 50. 

 

This can be resolved in two ways. 

  • The author's intended solution is to rewrite the system as $n$ quadratic equations in $K$ variables, where $K$ is the dimension of the kernel. This is done by expressing every candidate solution for the $2D$ variables as one particular solution plus a vector in the span of the computed kernel basis. As $K$ is much smaller than $2D$, we can solve this quadratic system by linearization: treat each quadratic term as a separate linear variable, and solve the resulting linear system in $\mathcal{O}(K^2)$ variables. This fails if $K$ is large, but the probability of that is small enough that you can simply retry.  
  • soon-haari's solution selects variables whose fixing turns as many equations as possible into linear ones, then brute forces them. Apparently, brute forcing around 3 to 7 variables is enough to solve everything with a linear system. The author considered this approach as well, and judged it to be of similar difficulty. Congratulations to soon-haari for the solve!
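To make the linearization step concrete, here is a small sketch in plain Python (not the Sage solver itself; `expand_product` and `evaluate` are hypothetical helper names). Each unknown coefficient is written as a particular value plus a combination $\sum_k c_k \cdot \text{kernel}_k$, and the product of two such unknowns expands into a constant, terms linear in $c_k$, and monomials $c_i c_j$ that are then treated as independent variables.

```python
import itertools

Q = 3  # we work over F_3

def expand_product(a0, a_ker, b0, b_ker):
    """Expand (a0 + sum_k c_k*a_ker[k]) * (b0 + sum_k c_k*b_ker[k]) mod 3.

    Returns (const, linear, quad), where quad[(i, j)] with i <= j is the
    coefficient of the monomial c_i * c_j (one new variable per monomial).
    """
    K = len(a_ker)
    const = (a0 * b0) % Q
    linear = [(a0 * b_ker[k] + b0 * a_ker[k]) % Q for k in range(K)]
    quad = {}
    for i, j in itertools.product(range(K), repeat=2):
        key = (min(i, j), max(i, j))
        quad[key] = (quad.get(key, 0) + a_ker[i] * b_ker[j]) % Q
    return const, linear, quad

def evaluate(const, linear, quad, c):
    """Evaluate the expanded form at a concrete kernel coordinate vector c."""
    total = const + sum(l * ck for l, ck in zip(linear, c))
    total += sum(v * c[i] * c[j] for (i, j), v in quad.items())
    return total % Q
```

With $K$ kernel vectors this produces $K$ linear variables plus $K(K+1)/2$ monomial variables, so the approach only has a chance when that count stays below the $n$ available equations.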

 

Task 2

From solving task 1, it should be clear that the goal is to create as many linear equations as possible, and the best way to do that is to make the erased coefficients consecutive. Note that $D = \lfloor n/4 \rfloor$. Now how do we do that?

 

Looking at CPython's implementation of random.sample, we can see that the sampling works by 

  • selecting a random number below $n - i$
  • taking the value at that index
  • swapping with the value at position $n - i - 1$ so it's not selected again
if n <= setsize:
    # An n-length list is smaller than a k-length set.
    # Invariant:  non-selected at pool[0 : n-i]
    pool = list(population)
    for i in range(k):
        j = randbelow(n - i)
        result[i] = pool[j]
        pool[j] = pool[n - i - 1]  # move non-selected item into vacancy

 

The first idea is that our consecutive selections should lie between $3n/4$ and $n$. If we try to pick everything from the front, the swaps with elements from the back make everything very complicated. If we pick everything at the back, the swapping process doesn't matter. Our goal is that for each $0 \le i < 2D$, the $i$th randbelow call should return a value $x$ such that $$n - D \le x < n - (i \bmod D)$$ To do this efficiently, we need to minimize the number of bits of the randbelow results that we constrain.
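The claim that the swaps do not matter when every draw lands in the tail can be checked with a small replay of the pool loop. This is a sketch in plain Python with a hypothetical helper `sample_with_draws`, which takes the randbelow outputs as an explicit list instead of drawing them:

```python
def sample_with_draws(n, draws):
    """Replay the pool branch of random.sample with pre-chosen randbelow outputs."""
    pool = list(range(n))
    result = []
    for i, j in enumerate(draws):
        assert 0 <= j < n - i          # range of randbelow(n - i)
        result.append(pool[j])
        pool[j] = pool[n - i - 1]      # move non-selected item into vacancy
    return result

# Toy parameters: every draw j satisfies n - D <= j < n - i.
n, D = 20, 5
draws = [n - D + (3 * i) % (D - i) for i in range(D)]
picked = sample_with_draws(n, draws)
print(picked)
```

Because both the read position `j` and the swapped-in position `n - i - 1` stay inside the last $D$ slots, the selected values are just a permutation of $n - D, \dots, n - 1$.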

 

This can be done by finding $t, e$ such that $$n - D \le t \cdot 2^e < (t + 1) \cdot 2^e \le n - (i \pmod{D})$$ and maximizing $e$. Now, it suffices to constrain that the randbelow result is equal to $t$ after being shifted right $e$ bits. 
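A minimal sketch of this search (plain Python; `best_prefix` is a hypothetical helper name) tries exponents from large to small and takes the first aligned block of width $2^e$ that fits in the allowed interval:

```python
def best_prefix(lo, hi):
    """Return (t, e) with lo <= t * 2^e and (t + 1) * 2^e <= hi, maximizing e.

    Returns None if the interval [lo, hi) is empty.
    """
    for e in range(hi.bit_length(), -1, -1):
        t = -(-lo // (1 << e))          # ceil(lo / 2^e): first aligned block start >= lo
        if (t + 1) << e <= hi:
            return t, e
    return None

n, D = 8501, 2125
print(best_prefix(n - D, n))            # first call: i = 0
print(best_prefix(n - D, n - (D - 1)))  # last call of a sample: i mod D = D - 1
```

For the first call the interval is wide and only the top few bits need fixing, while the last call of each sample leaves a single allowed value, so every bit of that result is pinned.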

 

With this constraint in mind, finding the random seed is a relatively standard & well-known part. Check 

 

from sage.all import * 
from pwn import * 
import os, time, hashlib
from tqdm import tqdm
 
master_seed = 600247940387898081354933536976866169084930632251853969488899917110020198810122609988929096745377141098741723367603792007316113444953792666862571706242930286155559269946806154361288678368457359018142990326230342798952338465389526939929502078802051258754480878558367725418213538959686293298529080289514822920464296348721597575544096597458429001817265022848553184840667785354715952925491523487232645147103986448095390219215781145998518651111533314287847331496377177766269568868692952126000367385409305094373697229732019698371694653901062958048654323973967406923027978294006963508011974550576372219806170598421126412329692337871189650539677852229330315493234530239153039612465282259165099628690517255702555717866959582273555272427097850669764318062511983827775307144138318241304317221780730344194239225249915593596200242887133280863594034149337943372866716544354182658575979491125042784060157047204416966107321296863765222506909590708256000704507083536224730714949757792943506954055583039145300983656869125027668944939614895061863310535572904493604125404238922537561024875138060290938298937958654115280469744785819844373294464304556541833829590478050223707009757821912803731639862859479441140085257990591633319400376044473459807757909163253108870976859578583670943065098654503069009249685436302811143447766284094640779045625221140912181101995321855673291954605585465787263376598442922960957301383956878534674646868253773621405250183320643364980373337456695064108850867851711815129534420771465598881853211932651613478641078277584065499568905136220098270549837365895803725718275849448166353382374417228364041220468302950981288770319399118644230636718448690522681870067570239520460833132685426300408679777837067556577415312306333936039952841571542797504778267306389352031160320772044273044992087572880617461937078673501534167707648215989202137307111149655907433317273123534590978003870024083976581229009848408082463718214510139613897317872270091573986923710996608581822710367444481663404008957
84346173287333244495255001791730081569336839520662381405144290409455974037142516983228317936699727842606081642240578844883094341186246580921199983915842081314638611291380152731368056730650469228559623662052531729887982086548223154333361363334651179104285034483322425275792671395416574805801066174116333550082050898468922185218775818163991247447213512072133729905831026979204387449740237107250035767499446745374366910625527910665062130250560706824973131871130316743119025554746764821434156674317934199946146598426673815806872877856328467376086085059620030855489151545303043307029412045887534894425887717583117437877318171620007894084301274965980741305242376340156346230267114570705189461331266648670135971622540649186222606301719344988665012528312715455324776650769063545393847491422947171743363094737848164203527815500665056795598783088145237029472339734427963849835022178538693065735805207641271044631695148529356782873975936571285135245429745146108036491430498162500763656245871098506267158087440237750159751664863211065697979948218429442546498813889729910223944717520161604727418685340138991981985619225849416335059973922841803883389393526433881970295102025473967271242147333800098274311358630381328801550779263853582690762954190377450436922016091335950930865947431088321016971794583463543506805777276891066135310863296397756985031896218023594355682462504506734940500365812292755553197773898889660699577561530145492927609494642174649101705826746672666183294279559291981779160964522297534960255063618332466976614501183228998682729197968575051178082566996628571473801354118883690941507931744401374987244372715844496021747653016503181871565513365490311937754665491041784007233799634349717657539855346128371782526748645254711640975576359645995419371902147356469195156372805778603944262661098290749863028288989422941985200606780481844930898792872531924601622394832468102957298372210031427732699319934781107988301247600593511823195668971916883951609735878765097692081037005879736118141874324737540398705
28628898868143676512219342728568678475826781164678210998419353749351703502533640513727310956886977762223250311883061366585070049619237133748561891062531941965758466150162037597483683241178632250773197702757562578325026641437878714890089598394985650056456477929783059142757507929786264914124437548580580159346319736493592106329059522031942775254080581241725743327635462741904444201810196959266855597834841572091374341701071004351993206634133734900236622739688134099808039340137608050887941648387563410546319031205837083870050009675159929085848993932016331246044619428794374148662646504416791676430945326586722774900164075752451632023819750121799717685483924225624773732284620588143132770819811791057843148198338345371976747591016726517082722540967807019820792733309767161161608067151765206708655734958216762375466031373225998466135473915182800095407226166996676537336762557778688067342059770068001267994334278527471257595090499311214166004779598384770251344487281204764223000391636054307627775052521885739919709605676023078715692931182063299587891253980107901944691572825356340672205987947417688169681916344667513957632295766347685938352076644452846959792807676992574813866017885364273959813993332685825937126762693812584071923955708499464787735731502937011693935739256407391195369251421225598809525335667453169032947851800653933070086986695265877326584996708581740943610336031758374988090417699238029832458617802069793649801089906194734593920496096717716400664898629580493716061526281377087969344596710425939080351652442280290863747025179796959314803903847754008034095462419556952925384762915407527425719986114712505521661481700369351988491698320537740147376296750844581946733379731288367527549771751982747070208877058469042855065015881314440129687568328936089700959446764324570077353541191379712190438204402986570734032483389540060097400998119799264758650784670461602168794077793317490055189771645182165737927872248581995496711716884287151076363168218448873513779212069766646941997300945105709404750
7237995094164031163468702
 
P = PolynomialRing(ZZ, 'x')
x = P.gen()
 
# r = process(['sage', '../prob/for_organizer/task.sage'])
r = remote("52.79.59.27", 17776)
 
def solvePoW():
    PoW_start = time.time()
    prefix = bytes.fromhex(r.recvline().strip().decode())
    cnt = 0
    while True:
        cnt += 1
        if cnt % 1000000 == 0:
            print(cnt)
        ret = os.urandom(24)
        result = hashlib.sha256(prefix + ret).digest()
        if result[:3] == b"\x00\x00\x00":
            r.sendline(ret.hex())
            break
    PoW_end = time.time()
    print("PoW", PoW_end - PoW_start)
 
def balance_mod(f, q):
    tt = f.coefficients(sparse = False)
    ret = 0
    for i in range(len(tt)):
        cc = int((tt[i] + q // 2) % q) - q // 2
        ret += cc * (x ** i)
    return ret
 
def pad(n, arr):
    while len(arr) < n:
        arr.append(0)
    return arr
 
def encode(n, arr):
    res = 0
    for i in range(n):
        assert -1 <= arr[i] <= 1
        res += (arr[i] + 1) * (3 ** i)
    return res 
 
def decode(n, v):
    ret = [0] * n
    for i in range(n):
        ret[i] = v % 3 - 1
        v = v // 3 
    return ret
 
def read_input(n, D):
    s1 = r.recvline()
    s2 = r.recvline()
    if b"Traceback" in s1:
        for _ in range(20):
            print(r.recvline())
    sel1 = eval(s1)
    sel2 = eval(s2)
    coef_original = decode(n, int(r.recvline().strip()))
    coef_inverse = decode(n, int(r.recvline().strip()))
    return sel1, sel2, coef_original, coef_inverse
 
def solve_init(n, D, sel1, sel2, coef_original, coef_inverse):
    mat, vec = [], []
    isok = [1 for _ in range(n)]
    for i in range(D):
        for j in range(D):
            isok[(sel1[i] + sel2[j]) % n] = 0
    
    idxs1 = [-1] * n
    idxs2 = [-1] * n
    for i in range(D):
        idxs1[sel1[i]] = i 
        idxs2[sel2[i]] = i
    
    for i in tqdm(range(n)):
        if isok[i] == 1:
            coefs = [0] * (2 * D)
            val = 0
            if i == 0:
                val = 1
            for j in range(n):
                if idxs1[j] != -1:
                    coefs[idxs1[j]] = coef_inverse[(i - j) % n]
                if idxs2[(i - j) % n] != -1:
                    coefs[idxs2[(i - j) % n] + D] = coef_original[j]
                if idxs1[j] == -1 and idxs2[(i - j) % n] == -1:
                    val -= coef_original[j] * coef_inverse[(i - j) % n]
            mat.append(coefs)
            vec.append(val % 3)
    
    M = Matrix(GF(3), mat)
    v = vector(GF(3), vec)
    sol = M.solve_right(v)
 
    kernel = M.right_kernel().basis()
    return sol, kernel 
 
def solve_task1(n, D):
    sel1, sel2, coef_original, coef_inverse = read_input(n, D)
 
    sol, kernel = solve_init(n, D, sel1, sel2, coef_original, coef_inverse)
 
    idxs1 = [-1] * n
    idxs2 = [-1] * n
    for i in range(D):
        idxs1[sel1[i]] = i 
        idxs2[sel2[i]] = i
 
    fvar = len(kernel)
    print(fvar)
    tot = fvar + (fvar * (fvar - 1)) // 2 + fvar
 
    idxs = [[0 for _ in range(fvar)] for _ in range(fvar)]
    for i in range(fvar):
        idxs[i][i] = i
    
    cur = fvar 
    for i in range(fvar):
        for j in range(i + 1, fvar):
            idxs[i][j] = cur
            cur += 1
    
    def single(x):
        return fvar + fvar * (fvar - 1) // 2 + x
 
    def wow(x1, x2):
        if x1 == x2:
            return x1
        else:
            if x1 > x2:
                x1, x2 = x2, x1
            return idxs[x1][x2]
 
    mat = []
    vec = []
 
    for i in tqdm(range(n)):
        coefs = [0] * tot
        val = 0
        if i == 0:
            val = 1
        for j in range(n):
            # [j] * [i - j]
            if idxs1[j] == -1 and idxs2[(i - j) % n] == -1:
                val -= coef_original[j] * coef_inverse[(i - j) % n]
            if idxs1[j] == -1 and idxs2[(i - j) % n] != -1:
                idx2 = idxs2[(i - j) % n]
                val -= coef_original[j] * sol[idx2 + D]
                for k in range(fvar):
                    coefs[single(k)] += coef_original[j] * kernel[k][idx2 + D]
            if idxs1[j] != -1 and idxs2[(i - j) % n] == -1:
                idx1 = idxs1[j]
                val -= coef_inverse[(i - j) % n] * sol[idx1]
                for k in range(fvar):
                    coefs[single(k)] += coef_inverse[(i - j) % n] * kernel[k][idx1]
            if idxs1[j] != -1 and idxs2[(i - j) % n] != -1:
                idx1 = idxs1[j]
                idx2 = idxs2[(i - j) % n]
                val -= sol[idx1] * sol[idx2 + D]
                for k in range(fvar):
                    coefs[single(k)] += sol[idx1] * kernel[k][idx2 + D]
                    coefs[single(k)] += sol[idx2 + D] * kernel[k][idx1]
                for k1 in range(fvar):
                    for k2 in range(fvar):
                        coefs[wow(k1, k2)] += kernel[k1][idx1] * kernel[k2][idx2 + D]
        mat.append(coefs)
        vec.append(val)
 
    M = Matrix(GF(3), mat)
    v = vector(GF(3), vec)
    final_sol = M.solve_right(v)
 
    fins = [0] * (2 * D)
 
    for i in range(2 * D):
        fins[i] += sol[i]
    
    for i in range(2 * D):
        for j in range(fvar):
            fins[i] += final_sol[single(j)] * kernel[j][i]    
 
    recover_f = 0
    recover_f3 = 0
    for i in range(n):
        if i in sel1:
            recover_f += fins[sel1.index(i)] * (x ** i)
        else:
            recover_f += coef_original[i] * (x ** i)
    
    for i in range(n):
        if i in sel2:
            recover_f3 += fins[sel2.index(i) + D] * (x ** i)
        else:
            recover_f3 += coef_inverse[i] * (x ** i)
    
    recover_f = balance_mod(recover_f, 3)
    recover_f3 = balance_mod(recover_f3, 3)
 
    r.sendline(str(encode(n, pad(n, recover_f.coefficients(sparse = False)))))
    r.sendline(str(encode(n, pad(n, recover_f3.coefficients(sparse = False)))))
 
def solve_task2(n, D):
    r.sendline(str(master_seed))
    sel1, sel2, coef_original, coef_inverse = read_input(n, D)
 
    for i in range(D):
        assert n - n // 4 <= sel1[i] < n 
        assert n - n // 4 <= sel2[i] < n
 
    sol, kernel = solve_init(n, D, sel1, sel2, coef_original, coef_inverse)
 
    print("task2 kernel", len(kernel))
 
    recover_f = 0
    recover_f3 = 0
    for i in range(n):
        if i in sel1:
            recover_f += sol[sel1.index(i)] * (x ** i)
        else:
            recover_f += coef_original[i] * (x ** i)
    
    for i in range(n):
        if i in sel2:
            recover_f3 += sol[sel2.index(i) + D] * (x ** i)
        else:
            recover_f3 += coef_inverse[i] * (x ** i)
 
    recover_f = balance_mod(recover_f, 3)
    recover_f3 = balance_mod(recover_f3, 3)
 
    r.sendline(str(encode(n, pad(n, recover_f.coefficients(sparse = False)))))
    r.sendline(str(encode(n, pad(n, recover_f3.coefficients(sparse = False)))))
 
st = time.time()
 
solvePoW()
 
for _ in tqdm(range(8)):
    solve_task1(2411, 83)
 
for _ in tqdm(range(8)):
    solve_task2(8501, 2125)
 
print(r.recvline())
 
en = time.time()
 
print(en - st)

 


https://www.youtube.com/watch?v=8wsR7o0rOxU 

 


Introduction

I participated in the first cohort of the Axiom Open Source Program. After studying the theory of ZK-friendly hashes and being fascinated by it, I decided to implement some of them for this program. My target was the (at the time) newly developed Poseidon2 hash function. After a while, I kept thinking about what to do next: keep implementing more hashes, or go in a different direction. At this point, a friend asked me about an interesting puzzle, which went like this. 

Suppose a password-based key management system stores the user's key as $E(pw, K)$. Suppose the user now wants to change the password into $pw'$, so the storage should change to $E(pw', K)$. How should the system verify that this new value is still an encryption of $K$, without knowing $pw, pw', K$ at all?

 

This was a very interesting, real-world puzzle, and some searching led me to the theory of verifiable encryption, where a certain property of an encrypted plaintext is proved. It's also clear that ZKP can give us a solution here. 

 

By allowing the system to store $Hash(K)$, we can change this problem to 

 Prove that the user knows $pw, K$ such that $Hash(K) = A$ and $E(pw, K) = B$ where $A, B$ are stored on the system.

 

Having selected Poseidon2 as the hash function, I was left with choosing $E$, and I decided on AES. For simplicity, I chose AES-ECB. 

I also decided to use pure halo2-lib as much as possible. I had already implemented Poseidon2 in halo2-lib at the time, and mixing vanilla halo2 with Axiom's halo2-lib is definitely not an easy task. 

 

Implementing Poseidon2 

To discuss the implementation aspects of Poseidon2, we first need to understand how Poseidon and Poseidon2 work. 

Roughly speaking, both hash functions use a sponge construction, which means that the hash is built from a permutation. Poseidon has a width parameter $t$, meaning that the permutation maps $\mathbb{F}_p^t \rightarrow \mathbb{F}_p^t$. To design this permutation, Poseidon uses three types of layers: round constant addition, the MDS matrix linear layer, and the SBOX layer. 

 

The round constant addition layer is straightforward - it simply adds a round constant to each element. 

The MDS matrix linear layer is also straightforward - it's a matrix multiplication. The "MDS" part is a description about the matrix which is needed for security analysis, but for implementation/understanding purposes it's not very important. 

The SBOX layer is $S(x) = x^\alpha$, where $\alpha$ is the smallest integer $\alpha \ge 3$ such that $\gcd(\alpha, p - 1) = 1$. For BN254, this gives $\alpha = 5$. 
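As a quick sanity check of the exponent choice, here is a sketch in plain Python. It assumes that "BN254" refers to the BN254 scalar field with the modulus below, and that the search starts from $\alpha = 3$ as in the Poseidon papers; the helper names are hypothetical.

```python
from math import gcd

# Assumption: the BN254 scalar field modulus.
P_BN254 = 21888242871839275222246405745257275088548364400416034343698204186575808495617

def smallest_sbox_exponent(p):
    """Smallest alpha >= 3 with gcd(alpha, p - 1) = 1, so x -> x^alpha permutes F_p."""
    alpha = 3
    while gcd(alpha, p - 1) != 1:
        alpha += 1
    return alpha

def sbox(x, p, alpha):
    return pow(x, alpha, p)

def sbox_inverse(y, p, alpha):
    # The inverse map raises to alpha^{-1} mod (p - 1) (Python 3.8+ for pow with -1).
    return pow(y, pow(alpha, -1, p - 1), p)

print(smallest_sbox_exponent(P_BN254))
```

Here $p - 1$ is divisible by both 2 and 3, which rules out $\alpha = 3$ and $\alpha = 4$.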

 

The most interesting part of Poseidon is the distinction between full rounds and partial rounds. The idea is that not every round needs to apply the S-box to every element of the state. Instead, we can use partial rounds, which apply the S-box to only a single element of the state. By running $R_f = R_F / 2$ full rounds, then $R_P$ partial rounds, then another $R_f = R_F / 2$ full rounds, we maintain security while saving many S-boxes, leading to a more efficient hash function. The outline of this permutation is shown in a figure below. 

So what's the difference between Poseidon and Poseidon2? There are some subtle differences, but the main one lies in the MDS matrix linear layer. The matrices are generated differently, for better native runtime and lower cost inside a ZKP. The matrices for the external (full) rounds and the internal (partial) rounds are also different from each other. This permutation's layout is shown in the figure below.
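The round schedule can be illustrated with a toy permutation in plain Python. Everything below is an assumption made for illustration: the prime, the matrix, the round constants and the round counts are toy values with no security, and applying the partial-round S-box to the first state element is one common convention.

```python
P = 97                                   # toy prime; gcd(5, P - 1) = 1
T, R_F, R_P, ALPHA = 3, 8, 5, 5          # toy width and round counts
MAT = [[2, 1, 1], [1, 2, 1], [1, 1, 2]]  # toy invertible matrix, NOT a vetted MDS matrix

def round_constants(r):
    # Toy constants; real parameters come from the Grain LFSR generation.
    return [(7 * r + 3 * i + 1) % P for i in range(T)]

def permute(state):
    """Run R_F/2 full rounds, R_P partial rounds, R_F/2 full rounds; count S-boxes."""
    sbox_calls = 0
    for r in range(R_F + R_P):
        state = [(s + c) % P for s, c in zip(state, round_constants(r))]
        is_full = r < R_F // 2 or r >= R_F // 2 + R_P
        if is_full:
            state = [pow(s, ALPHA, P) for s in state]
            sbox_calls += T
        else:
            state[0] = pow(state[0], ALPHA, P)
            sbox_calls += 1
        state = [sum(m * s for m, s in zip(row, state)) % P for row in MAT]
    return state, sbox_calls

out, calls = permute([1, 2, 3])
print(out, calls)
```

The S-box count is $R_F \cdot t + R_P$ instead of $(R_F + R_P) \cdot t$, which is where the savings come from.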

As Poseidon is already implemented in Axiom's halo2-lib, all I needed to do was implement these differences. 

 

Grain LFSR & The Parameter Generation

The first part is the parameter generation algorithm. For Poseidon, this is implemented in halo2/primitives/poseidon.

The round constants and the matrices are generated with a Grain LFSR, which is initialized from basic parameters such as $R_F, R_P$. Because the matrix format differs between Poseidon and Poseidon2, the generation algorithm itself is also quite different. I implemented the same algorithm as the one in the Horizen Labs repository. 


There is one interesting part of the matrix generation algorithm that is common to both Poseidon and Poseidon2: the test for the so-called invariant subspace trails. Why this is important and how to test for it is beyond the scope of this blog post, but interested readers should dive into the literature on the cryptanalysis of Poseidon. In practice, this means that we sometimes need to regenerate the matrices when a generated matrix fails this check. However, implementing the check in Rust is quite time consuming, as it involves computing minimal polynomials of matrices. Therefore, I hardcoded the number of tries it takes to reach a matrix that satisfies the necessary checks. The unfortunate consequence is that the implementation is not fully generic, as it assumes that we are working over the BN254 field. If there is a Rust library for minimal polynomials of matrices, this can be rewritten to be generic over any prime field. 

 

Implementation of Matrix Layers

While there are many optimization tricks in Poseidon, many of them are not relevant in Poseidon2. The main trick in Poseidon2 is that the matrices are designed to be easy to multiply, both in native computation and in the ZKP world. The overall implementation strategy was taken from the Horizen Labs implementation. These strategies are also described in the Appendix of the Poseidon2 paper. 

The main operation used to implement these matrix layers is `mul_add` in `GateInstructions`. 

 

 

Interesting Issues on the Horizen Labs Implementation

During the implementation process, I found some very interesting issues/points in the Horizen Labs implementation. This is quite awkward, as the Horizen Labs implementation is the reference implementation after all, and it is the implementation that is mentioned in the Poseidon2 paper itself. Therefore, the questions I mention below may serve little to no purpose. With that in mind, here they are.

 

The first one is the Grain LFSR parameters. In the Poseidon parameter generation, the SBOX parameter is selected as 0 if the SBOX is of $x^\alpha$ with small positive $\alpha$ and 1 if the SBOX is $x^{-1}$. In the Poseidon2 parameter generation, it's the opposite - the SBOX parameter is 1.

 

The second one is in the plain implementation itself. In the Poseidon2 parameter generation, it's clear that the external matrix in the case $t = 4$ is simply $M_4$. However, the plain implementation uses $2M_4$ as the external matrix. This happens because the matrix for $t = 4t'$ with $t' \ge 2$ is the circulant matrix $circ(2M_4, M_4, \cdots , M_4)$, and the implementation forgot to handle the case $t = 4$ separately.
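To make the shape of the external matrix concrete, here is a sketch in plain Python of the block structure described above, with the $t = 4$ case handled separately. The function name is hypothetical, and the $4 \times 4$ block used in the usage lines is a placeholder, not the actual $M_4$ from the Poseidon2 paper.

```python
def external_matrix(m4, t):
    """Return M4 for t = 4, and circ(2*M4, M4, ..., M4) for t = 4 * t' with t' >= 2."""
    assert t % 4 == 0 and t >= 4
    if t == 4:
        return [row[:] for row in m4]      # the special case: plain M4, not 2*M4
    blocks = t // 4
    out = [[0] * t for _ in range(t)]
    for bi in range(blocks):
        for bj in range(blocks):
            scale = 2 if bi == bj else 1   # 2*M4 on the block diagonal, M4 elsewhere
            for i in range(4):
                for j in range(4):
                    out[4 * bi + i][4 * bj + j] = scale * m4[i][j]
    return out

# Placeholder block for illustration only (NOT the real M4).
toy_m4 = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
m8 = external_matrix(toy_m4, 8)
```

Applying the general formula blindly with $t' = 1$ would produce $2M_4$, which matches the discrepancy described above.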

 

This issue is now fixed on the Poseidon2 repository.

 

Implementing AES

AES-ECB is, well, AES-ECB. Looking at a pure Python implementation like this one, we see that we need to implement the SBOX, the byte XOR operations, the "xtime" operation, and the byte range check. The rest is a straightforward implementation. 

 

Implementing the SBOX

There are three ways to proceed here. 

- Use a lookup table of size $2^8$

- Create a SBOX table as a witness, then use Axiom's "select_from_idx"

- Implement the $GF(2^8)$ arithmetic and the affine transformation on $\mathbb{F}_2^8$

 

The third option seemed way too complex, so the initial implementation used the second option. However, as you can expect, this is very inefficient, so a lookup table had to be used. The issue is that using pure Axiom halo2-lib and lookup tables at the same time is quite non-trivial, especially if multiple tables are needed. To use a lookup table, I followed the methodology of the RangeChip and the RangeCircuitBuilder, practically copy pasting everything except for the actual lookup table part. I added 0 and $256 \cdot (x + 1) + S(x)$ for every byte $x$ to the lookup table $T$. Then, I could claim that $y = S(x)$ if $x, y$ are both within $[0, 256)$ and $256 \cdot (x + 1) + y \in T$. 
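The encoding can be checked outside the circuit with a short plain-Python sketch. The S-box is computed from the standard AES definition (inverse in $GF(2^8)$ followed by the affine map); the helper names are hypothetical and the membership test stands in for the in-circuit lookup.

```python
def gf_mul(a, b):
    """Multiply in GF(2^8) modulo the AES polynomial x^8 + x^4 + x^3 + x + 1."""
    r = 0
    while b:
        if b & 1:
            r ^= a
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
        b >>= 1
    return r

def gf_inv(a):
    """a^254, which is the inverse for a != 0 and maps 0 to 0."""
    r = 1
    for _ in range(254):
        r = gf_mul(r, a)
    return r if a else 0

def rotl8(x, k):
    return ((x << k) | (x >> (8 - k))) & 0xFF

def aes_sbox(x):
    b = gf_inv(x)
    return b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63

SBOX = [aes_sbox(x) for x in range(256)]
TABLE = {0} | {256 * (x + 1) + SBOX[x] for x in range(256)}

def sbox_lookup_holds(x, y):
    """The claim y = S(x): both values are bytes and the packed value is in the table."""
    return 0 <= x < 256 and 0 <= y < 256 and 256 * (x + 1) + y in TABLE
```

The offset $x + 1$ keeps every real entry at 256 or above, so the extra entry 0 cannot be mistaken for an S-box pair, and the byte range checks on $x, y$ are what make the packing $256 \cdot (x + 1) + y$ unambiguous.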

 

 

Implementing Byte XORs and "xtime"

There are two ways to continue here. 

- Again, use a lookup table

- Decompose everything as bits, then use bit XORs to implement byte operations

 

At first, I implemented the second approach. A bit XOR can be implemented with a not gate and a select gate. 
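In plain Python arithmetic, the bit-level construction looks roughly like the sketch below. The functions `not_gate` and `select` are stand-ins written for this example, not the halo2-lib gates themselves.

```python
def not_gate(b):
    return 1 - b                       # for b in {0, 1}

def select(a, b, sel):
    return sel * a + (1 - sel) * b     # a if sel == 1 else b

def bit_xor(a, b):
    # If a == 1 the result is NOT b, otherwise it is b.
    return select(not_gate(b), b, a)

def to_bits(x):
    return [(x >> i) & 1 for i in range(8)]

def byte_xor(x, y):
    """XOR two bytes by decomposing into bits and recombining."""
    bits = [bit_xor(a, b) for a, b in zip(to_bits(x), to_bits(y))]
    return sum(bit << i for i, bit in enumerate(bits))

print(hex(byte_xor(0xA5, 0x3C)))
```

Each byte XOR costs two 8-bit decompositions plus eight not/select pairs in this style, which is part of why the lookup-based version is attractive.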

 

However, I turned to using a lookup table in hopes of optimizing the circuit. I added $2^{24} + 2^{16} \cdot a + 2^8 \cdot b + (a \oplus b)$ to the lookup table for all bytes $a, b$. With the assumption that $a, b, c \in [0, 256)$, the condition $2^{24} + 2^{16} \cdot a + 2^8 \cdot  b + c \in T$ is enough to force $c = a \oplus b$. 

 

The same goes for the xtime operation. I added $2^{25} + 2^8 \cdot x + xtime(x)$ to the lookup table, and with the assumption that $a, b \in [0, 256)$, $2^{25} + 2^8 \cdot a + b \in T$ is enough to force $b = xtime(a)$. 
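The XOR and xtime encodings can be sketched together in plain Python. The tags $2^{24}$ and $2^{25}$ come from the text; the helper names are hypothetical, and set membership stands in for the circuit lookup.

```python
def xtime(x):
    """Multiply a byte by x in GF(2^8) modulo the AES polynomial."""
    x <<= 1
    return (x ^ 0x11B) if x & 0x100 else x

XOR_TAG, XTIME_TAG = 1 << 24, 1 << 25

table = {0}
table |= {XOR_TAG + (a << 16) + (b << 8) + (a ^ b)
          for a in range(256) for b in range(256)}
table |= {XTIME_TAG + (x << 8) + xtime(x) for x in range(256)}

def is_byte(v):
    return 0 <= v < 256

def xor_lookup_holds(a, b, c):
    return all(map(is_byte, (a, b, c))) and XOR_TAG + (a << 16) + (b << 8) + c in table

def xtime_lookup_holds(a, b):
    return is_byte(a) and is_byte(b) and XTIME_TAG + (a << 8) + b in table
```

With byte-sized inputs, the XOR entries occupy $[2^{24}, 2^{25})$ and the xtime entries start at $2^{25}$, so the two families never collide with each other inside the single table.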

 

Implementing the Byte Range Check

There are two ways to proceed here.

- Use a lookup table

- Decompose the byte to 8 bits

 

The issue with the first approach is that we are currently using a single lookup table. Also, many of the lookup checks so far are built on the assumption that every value is within $[0, 256)$. Therefore, performing byte checks with a lookup table (unless we somehow manage to use multiple lookup tables) leads to the danger of circular reasoning. I simply used the num_to_bits function of Axiom's halo2-lib to check that the values fit in 8 bits. This is quite costly, and it is the main optimization left to be done.
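For intuition, the decomposition-based range check boils down to two kinds of constraints: every witness bit is boolean, and the bits recompose to the value. The function below is a hypothetical plain Python model of those constraints, not the halo2-lib `num_to_bits` implementation.

```python
from itertools import product


def byte_range_check(value: int, bits: list[int]) -> bool:
    """Accepts iff `bits` is a valid 8-bit little-endian witness for `value`."""
    if len(bits) != 8:
        return False
    if any(b * (b - 1) != 0 for b in bits):  # booleanity: b is 0 or 1
        return False
    return sum(b << i for i, b in enumerate(bits)) == value


def has_byte_witness(value: int) -> bool:
    """Brute-force search over all 8-bit witnesses (for illustration only)."""
    return any(byte_range_check(value, list(w)) for w in product((0, 1), repeat=8))
```

A value has a passing witness exactly when it lies in $[0, 256)$, and the check does not rely on the lookup table at all, which is what avoids the circularity.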

 

 

Final Benchmarks

Taken directly from the final presentation, we see that Poseidon2 is better in ZKP terms when the width $t$ is large. This is natural: Poseidon2's main advantage usually shows up in native computation, and its ZKP cost improves as $t$ grows, since the special forms of the MDS matrices become more and more helpful in decreasing the cost. In a way, this benchmark agrees with the paper.

 

For AES, a single block costs around 66k cells in AES-128, so around 6k cells per AES round.

If multiple lookup tables can be supported, the 8-bit decomposition check can be removed for better performance.


https://github.com/rkm0959/rkm0959_presents/blob/main/ZKApplications.pdf

 


 

https://github.com/rkm0959/rkm0959_presents/blob/main/Polynomials_EllipticCurve.pdf

 


 

Finished 8th in ACSC Quals 2023. Solved all cryptography challenges and the warmup rev. I had an alumni meetup with NEXON Youth Programming Challenge award winners, so I went home at like 9PM with a beer and sangria inside me.

(Here, sangria is not the recent folding scheme for PLONK from Geometry Research)

 

Kinda gave up on the whole ACSC thing until it was around midnight, then I saw that I could make it if I solve all cryptography challenges.

So I hurried to do so and finished around 5:30AM, then solved the warmup rev. Turns out that this was a good idea.

 

Merkle-Hellman

#!/usr/bin/env python3
import random
import binascii
 
def egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = egcd(b % a, a)
        return (g, x - (b // a) * y, y)
 
def modinv(a, m):
    g, x, y = egcd(a, m)
    if g != 1:
        raise Exception('modular inverse does not exist')
    else:
        return x % m
 
def gcd(a, b): 
    if a == 0:
        return b 
    return gcd(b % a, a) 
 
flag = open("flag.txt","rb").read()
# Generate superincreasing sequence
w = [random.randint(1,256)]
s = w[0]
for i in range(6):
    num = random.randint(s+1,s+256)
    w.append(num)
    s += num
 
# Generate private key
total = sum(w)
q = random.randint(total+1,total+256)
r = 0
while gcd(r,q) != 1:
    r = random.randint(100, q)
 
# Calculate public key
b = []
for i in w:
    b.append((i * r) % q)
 
# Encrypting
c = []
for f in flag:
    s = 0
    for i in range(7):
        if f & (64>>i):
            s += b[i]
    c.append(s)
 
print(f"Public Key = {b}")
print(f"Private Key = {w,q}")
print(f"Ciphertext = {c}")
 
# Output:
# Public Key = [7352, 2356, 7579, 19235, 1944, 14029, 1084]
# Private Key = ([184, 332, 713, 1255, 2688, 5243, 10448], 20910)
# Ciphertext = [8436, 22465, 30044, 22465, 51635, 10380, 11879, 50551, 35250, 51223, 14931, 25048, 7352, 50551, 37606, 39550]

 

This is a knapsack-based cryptosystem, and we know practically everything here. Just decrypt it.

 

from sage.all import *
from Crypto.Util.number import inverse, long_to_bytes, bytes_to_long 
 
pk = [7352, 2356, 7579, 19235, 1944, 14029, 1084]
sk = [184, 332, 713, 1255, 2688, 5243, 10448]
q = 20910
ctxt = [8436, 22465, 30044, 22465, 51635, 10380, 11879, 50551, 35250, 51223, 14931, 25048, 7352, 50551, 37606, 39550]
 
 
df = sk[2] * inverse(pk[2], q)  # equals r^{-1} mod q
 
res = b""
 
for c in ctxt:
    c = (c * df) % q 
    f = 0
    for i in range(6, -1, -1):
        if c >= sk[i]:
            f += (1 << (6 - i))
            c -= sk[i]
    res += bytes([f])
 
print(res)
 

 

 

Check Number 63

from Crypto.Util.number import *
import gmpy2
from flag import *
 
f = open("output.txt","w")
 
f.write(f"n = {n}\n")
 
while e < 66173:
  d = inverse(e,(p-1)*(q-1))
  check_number = (e*d - 1) // ( (p-1)*(q-1) )
  f.write(f"{e}:{check_number}\n")
  assert (e*d - 1) % ( (p-1)*(q-1) ) == 0
  e = gmpy2.next_prime(e)
  
f.close()
 
 

 

We have $n$ alongside 63 pairs of $(e, k)$ where $$ed = k \phi(n) + 1$$ The goal is to factor $n$. The central idea is to note that $$k \phi(n) + 1 \equiv 0 \pmod{e}$$ so we recover $$\phi(n) \equiv - k^{-1} \pmod{e}$$ Combined with CRT, this gives us $$\phi(n) \pmod{ \prod e}$$ which is around $63 \cdot 16 = 1008$ bits of information. Meanwhile, since $$\phi(n) = n + 1 - (p + q)$$ we have a 1025 bit range for $\phi(n)$. The gap is small enough that we can brute-force all remaining candidates for $\phi(n)$, then recover $p, q$ from $n, \phi(n)$.
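The idea can be checked on a toy instance. The sketch below uses small made-up primes ($p = 1009$, $q = 1013$) and a handful of small exponents instead of the challenge data, and the helper `crt` is written by hand for illustration. Here the product of the exponents already exceeds $\phi(n)$, so no brute force is needed. In the actual challenge a few bits remain and are enumerated.

```python
from math import isqrt

p, q = 1009, 1013                     # toy primes, not the challenge values
n, phi = p * q, (p - 1) * (q - 1)
es = [5, 13, 17, 19, 29, 31]          # small primes coprime to phi

# What the challenge leaks: pairs (e, k) with e*d = k*phi + 1.
leaks = [(e, (e * pow(e, -1, phi) - 1) // phi) for e in es]


def crt(residues, moduli):
    """Incremental CRT for pairwise coprime moduli; returns (x, product)."""
    x, m = 0, 1
    for r, mod in zip(residues, moduli):
        t = ((r - x) * pow(m, -1, mod)) % mod
        x += m * t
        m *= mod
    return x, m


# phi ≡ -k^{-1} (mod e) for each leaked pair.
residues = [(-pow(k, -1, e)) % e for e, k in leaks]
phi_rec, modulus = crt(residues, [e for e, _ in leaks])

# Recover p, q from n and phi via p + q = n + 1 - phi.
s = n + 1 - phi_rec
root = isqrt(s * s - 4 * n)
p_rec, q_rec = (s - root) // 2, (s + root) // 2
```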

 

from sage.all import *
from Crypto.Util.number import inverse, long_to_bytes, bytes_to_long 
 
# ed = k phi(n) + 1
# k phi(n) + 1 == 0 mod e
# phi(n) == -k^-1 mod e
 
n = None  # given number here
 
def iroot(v, r):
    return Integer(v).nth_root(r, truncate_mode=False)
 
tt = open("output.txt", "r")
tt = tt.readlines()
 
vals = []
mods = []
 
for l in tt:
    wow = l.split(":")
    e = int(wow[0])
    res = int(wow[1])
    vals.append(int(e - inverse(res, e)))
    mods.append(e)
 
cc = int(prod(mods))
 
phi_cc = int(CRT(vals, mods))
 
for l in tt:
    wow = l.split(":")
    e = int(wow[0])
    res = int(wow[1])
    assert (phi_cc * res + 1) % e == 0
 
from tqdm import tqdm
 
for i in tqdm(range(1 << 20)):
    tot = int((n - phi_cc + 1) % cc)
    tot += i * cc
    phi = n - tot + 1
    assert phi % cc == phi_cc 
    try:
        p = (tot - iroot(tot * tot - 4 * n, 2)) // 2
        q = tot - p
        assert p < q 
        assert p * q == n
        print(p, q)
        from hashlib import sha512
        flag = "ACSC{" + sha512( f"{p}{q}".encode() ).hexdigest() + "}" 
        print(flag)
    except:
        pass
 
 
 
 
 
 

 

 

Dual Signature Algorithm

import os
from hashlib import sha256
from Crypto.Util.number import getPrime, isPrime, getRandomNBitInteger, inverse
 
 
flag = os.environ.get("FLAG", "neko{cat_are_the_most_powerful_beings_in_fact}")
 
 
def h(m: bytes) -> int:
    return int(sha256(m).hexdigest(), 16)
 
 
def gen_prime():
    while True:
        q = getPrime(520)
        p = 2*q + 1
        if isPrime(p):
            return p, q
 
 
p1, q1 = gen_prime()
p2, q2 = gen_prime()
 
if q1 > q2:
    (p1, q1), (p2, q2) = (p2, q2), (p1, q1)
 
x = int((os.urandom(512 // 8 - len(flag) - 1) + flag.encode()).hex(), 16)
g = 4
y1 = pow(g, x, p1)
y2 = pow(g, x, p2)
 
 
def sign(m: bytes):
    z = h(m)
    k = getRandomNBitInteger(512)
    r1 = pow(g, k, p1)
    r2 = pow(g, k, p2)
    s1 = inverse(k, q1) * (z + r1*x) % q1
    s2 = inverse(k, q2) * (z + r2*x) % q2
 
    return (r1, s1), (r2, s2)
 
 
def verify(m: bytes, sig1, sig2):
    z = h(m)
    r1, s1 = sig1
    r2, s2 = sig2
 
    s1inv = inverse(s1, q1)
    s2inv = inverse(s2, q2)
    gk1 = pow(g, s1inv*z, p1) * pow(y1, s1inv*r1, p1) % p1
    gk2 = pow(g, s2inv*z, p2) * pow(y2, s2inv*r2, p2) % p2
 
    return r1 == gk1 and r2 == gk2
 
 
m = b"omochi mochimochi mochimochi omochi"
sig1, sig2 = sign(m)
 
print(f"g = {g}")
print(f"p1, p2 = {p1}, {p2}")
print(f"y1, y2 = {y1}, {y2}")
 
print(f"m = {m}")
print(f"r1, s1 = {sig1}")
print(f"r2, s2 = {sig2}")
 

 

I overcomplicated this problem way too much - the easier solution is combining the two signature schemes via CRT then LLL.

I tried some straightforward lattices without the CRT idea, but it didn't give me the answer. Here's my solution.

 

Start with the equations $$ks_1 = z + r_1x + c_1q_1, \quad ks_2 = z + r_2x + c_2q_2$$ where $c_1, c_2$ each have absolute values of at most something like $2^{515}$. We'll cancel out the $k$ in the equations to get $$s_2(z + r_1x + c_1q_1) = s_1(z+ r_2x + c_2q_2)$$ or $$(s_2r_1 - s_1r_2)x = (s_1 - s_2)z + (s_1q_2) c_2 - (s_2q_1) c_1$$ This gives us a system - we want $$-2^{515} \le c_1, c_2 \le 2^{515}$$ as well as the linear equation $$(s_1q_2) c_2 \equiv (s_2q_1) c_1 + (s_2 - s_1)z \pmod{s_2r_1 - s_1r_2}$$ While there are some GCD issues like $\gcd(s_1q_2, s_2r_1 - s_1r_2) = 2 \neq 1$, in essence this is the same type of problem as $$S \le x \le E, \quad L \le Ax + B \bmod{M} \le R$$ which is the exact task that the special case of my CVP repository solves. After getting $c_1, c_2$, the rest is an easy linear equation.
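The algebra above can be sanity-checked numerically. The sketch below uses toy primes in place of the 520-bit $q_1, q_2$ and random integers in place of $r_1, r_2$ (in the challenge these are $g^k \bmod p_i$, but the identity does not depend on that). The function name `check_identity` is made up.

```python
import random


def check_identity(seed: int) -> bool:
    """Checks the k-free relation between (s1, s2, c1, c2) on toy parameters."""
    rng = random.Random(seed)
    q1, q2 = 1000003, 1000000007          # toy primes with q1 < q2
    x, z, k = (rng.randrange(1, q1) for _ in range(3))
    r1, r2 = rng.randrange(1, q1), rng.randrange(1, q2)

    s1 = pow(k, -1, q1) * (z + r1 * x) % q1
    s2 = pow(k, -1, q2) * (z + r2 * x) % q2

    # k*s_i = z + r_i*x + c_i*q_i holds over the integers for some integer c_i.
    c1, rem1 = divmod(k * s1 - z - r1 * x, q1)
    c2, rem2 = divmod(k * s2 - z - r2 * x, q2)
    if rem1 != 0 or rem2 != 0:
        return False

    # (s2 r1 - s1 r2) x = (s1 - s2) z + (s1 q2) c2 - (s2 q1) c1
    M = s2 * r1 - s1 * r2
    if M * x != (s1 - s2) * z + s1 * q2 * c2 - s2 * q1 * c1:
        return False

    # The same statement as a congruence modulo M (skipped if M happens to be 0).
    if M != 0 and (s1 * q2 * c2 - s2 * q1 * c1 - (s2 - s1) * z) % M != 0:
        return False
    return True
```

Calling `check_identity` with a few different seeds exercises the relation on independent random instances.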

 

def ceil(n, m): # returns ceil(n/m)
    return (n + m - 1) // m
 
def is_inside(L, R, M, val): # is L <= val <= R in mod M context?
    if L <= R:
        return L <= val <= R
    else:
        R += M
        if L <= val <= R:
            return True
        if L <= val + M <= R:
            return True 
        return False
 
def optf(A, M, L, R): # minimum nonnegative x s.t. L <= Ax mod M <= R
    if L == 0:
        return 0
    if 2 * A > M:
        L, R = R, L
        A, L, R = M - A, M - L, M - R
    cc_1 = ceil(L, A)
    if A * cc_1 <= R:
        return cc_1
    cc_2 = optf(A - M % A, A, L % A, R % A)
    return ceil(L + M * cc_2, A)
 
# check if L <= Ax (mod M) <= R has a solution
def sol_ex(A, M, L, R):
    if L == 0 or L > R:
        return True
    g = GCD(A, M)
    if (L - 1) // g == R // g:
        return False
    return True
 
## find all solutions for L <= Ax mod M <= R, S <= x <= E:
def solve(A, M, L, R, S, E):
    # this is for estimate only : if very large, might be a bad idea to run this
    print("Expected Number of Solutions : ", ((E - S + 1) * (R - L + 1)) // M + 1)
    if sol_ex(A, M, L, R) == False:
        return []
    cur = S - 1
    ans = []
    num_sol = 0
    while cur <= E:
        NL = (L - A * (cur + 1)) % M
        NR = (R - A * (cur + 1)) % M
        if NL > NR:
            cur += 1
        else:
            val = optf(A, M, NL, NR)
            cur += 1 + val
        if cur <= E:
            ans.append(cur)
            # remove assert for performance if needed
            assert is_inside(L, R, M, (A * cur) % M)
            num_sol += 1
    print("Actual Number of Solutions : ", num_sol)
    return ans
 
q1 = (p1 - 1) // 2
q2 = (p2 - 1) // 2
 
det_v = abs(s2 * r1 - s1 * r2)
 
md_2 = s1 * q2 
md_1 = s2 * q1 
 
z = (h(m) * (s2 - s1)) % det_v
 
print(z % 2) # 0
print(md_1 % 2) # 1
print(md_2 % 2) # 0
 
md_2 //= 2
det_v //= 2 
z //= 2
 
A = (md_1 * inverse(md_2, det_v)) % (det_v)
B = (z * inverse(md_2, det_v)) % det_v 
 
BOUND = 1 << 515
 
print(det_v)
 
pepega = solve(A, det_v, det_v -BOUND - B, det_v + BOUND - B, -BOUND, BOUND)
 
print(len(pepega))
 
c_1 = int(pepega[0] * 2)
c_2 = int((md_1 * (c_1 // 2) + z) * inverse(md_2, det_v) % det_v)
if c_1 > BOUND:
    c_1 -= det_v
if c_2 > BOUND:
    c_2 -= det_v
 
print(abs(c_1) < BOUND)
print(abs(c_2) < BOUND)
 
LHS = s2 * h(m) - s1 * h(m) + s2 * c_1 * q1 - s1 * c_2 * q2 
RHS = s1 * r2 - s2 * r1 
x = LHS // RHS 
print(LHS % RHS)
print(long_to_bytes(x))
 

 

 

Corrupted

We are given a broken PEM file that looks like the following.

-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAn+8Rj11c2JOgyf6s1Hiiwt553hw9+oGcd1EGo8H5tJOEiUnP
NixaIGMK1O7CU7+IEe43PJcGPPkCti2kz5qAXAyXXBMAlHF46spmQaQFpVRRVMZD
1yInh0QXEjgBBFZKaH3VLh9FpCKYpfqij+OlphoSHlfc7l2Wfct40TDFg13WdpVB
BseCEmaY/b+kxwdfVe7Dzt8kd2ASPuNbOqKvv8ijTgiqpsX5uinjvr/3/srINm8X
xpANqO/eSXP8kO4abOJtyfg2bWvO9QvQRaUIjnYioAkyiqcttbzGIekCfktlA+Rn
JLL19tEG43hubOZAwqGDxvXfKEKx9E2Yx4Da/wIDAQA?AoI?????8S??Om/???xN
3c??0?/G?OO?aQWQB??ECCi??KD?w??2mFc??pTM?r?rX??X+XFW??Rtw?o?d????ZQ?yp?mczG?q2?0O???1o3?Jt?8?+00s?SY+??MG??7d??7k??o?????ci?K??????wK??Y??gqV????9????YA?Hh5T????ICP+?3HTU?l???m0y?6??2???b2x???????+7??T????????n?7????b?P??iL?/???tq???5jLuy??lX?d?ZEO?7???ld???g
?r?rK??IYA???0???zYCIZt2S???cP??W????f???l5?3c+??UkJr4E?QH??PiiD
WLB???f5A?G?A???????????u???3?K???????I???S?????????J?p?3?N?W???
????r???????8???o???m?????8?s???1?4?l?T?3?j?y?6?F?c?g?3?A?8?S?1?
X?o?D?C?+?7?F?V?U?1?f?K?a?F?7?S?b?V?/?v?5?1?V?A?5?G?y?X?AoGB?L?i
?2?C?t?W?s?Z?h?L?t?3?r?d?M?s?U?E?L?P?n?2?U?G?M?g?D?u?E?s?a?h?K?m
?9?/?n?o?J?8?e?9?9?k?N?2?l?T?8?k?b?e?j?n?Q?u?z?z?e?A?S?6?0?w?5?0
?B?V?i?s?R?W?6?Y?6?u?l?s?G?c?Q?2?Q?w?U?l??GA??V?f???kVYfl???WyY?
3J?2fF?h/???UqfpeO???o?k?9kF??a8L?V?w??????J??9?iP????D???JSx??g??IUC0??t7???I??c??????eh/No?????y8???0?E+??1?JC?Oj??HFy??2T?1nV??HH?+???+??s?L?o??K?zc?????BhB2A?????E??b???e?f??KruaZ??u?tp?Tq?c?t?????iQ1qS??h??m?S?/????FDu3i?p???S??Q?o??0s?e0?n?Hv??C?CnM?/Dw
m9?????uC?Ktm????D?e????h7?A??V??O??5/XsY??Y?A???????q?y?gk?Pbq?
????MQK?gQ??SQ?????ERjLp?N??A??P?So?TPE??WWG???lK?Q????o?aztnUT?
eKe4+h0?VkuB?b?v?7ge?nK1??Jy7?y??9??????BP??gG?kKK?y?Z???yES4i??
?Uhc?p????c4ln?m?r???P??C?8?X?d??TP??k??B?dwjN7??ui?K????????-?N? ?S? ?RI?A?? KE?-???-

 

We need to recover the full PEM key. The solution is really hands-on, and it needs some grinding.

The PEM decoding algorithm is in pycryptodome - basically, it's just a DER decoding. So how does DER decoding work?

 

By following the DER implementation in pycryptodome alongside with some debugging, it's basically as follows.

  • 1 byte is consumed as the octet. In DER terms this is the tag byte that identifies the type (0x30 for SEQUENCE, 0x02 for INTEGER).
  • Then, 1 byte is consumed as the length $l$. If $l \le 127$, then this is the final length.
  • If $l \ge 128$, then the next $l \bmod 128$ bytes, in big endian, represent the final length.
  • The "final length" bytes worth of data, in big endian, is the pushed data. 

However, the very first "final length" is actually the length of the whole key (the outer sequence), so its data should not be consumed as a value and this one should be skipped. 
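The steps above can be sketched as a minimal tag-length-value reader. This is a simplified illustration that assumes the input is a single outer SEQUENCE of INTEGERs (as in an RSA private key) and does no error handling. It is not pycryptodome's implementation, and the names `read_header` and `parse_integers` are made up.

```python
def read_header(buf: bytes, pos: int) -> tuple[int, int, int]:
    """Reads one tag byte and one length field; returns (tag, length, new_pos)."""
    tag = buf[pos]
    length = buf[pos + 1]
    pos += 2
    if length >= 0x80:                      # long form: low 7 bits = number of length bytes
        num_len_bytes = length & 0x7F
        length = int.from_bytes(buf[pos:pos + num_len_bytes], "big")
        pos += num_len_bytes
    return tag, length, pos


def parse_integers(der: bytes) -> list[int]:
    """Parses SEQUENCE { INTEGER, INTEGER, ... } into a list of Python ints."""
    _tag, total, pos = read_header(der, 0)  # outer header: length covers everything, no data consumed
    end = pos + total
    values = []
    while pos < end:
        _tag, length, pos = read_header(der, pos)
        values.append(int.from_bytes(der[pos:pos + length], "big"))
        pos += length
    return values


# SEQUENCE { INTEGER 0, INTEGER 65537 }
example = bytes.fromhex("30080201000203010001")
```

A 1024-bit value whose top bit is set is stored in 129 bytes (one leading zero byte), so its length field reads `0x81 0x81`, which shows up as the byte pair 129, 129.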

 

Also, we quickly note that 

  • By comparing lengths, it can be seen that this file is based on 2048 bit RSA. 
  • In this case, the PEM file has a linebreak every 64 characters. Based on this, we can turn some "?" into linebreaks.
  • The first step is to base64 decode the lines between the first and the last ones. By randomly selecting one of 64 choices for each "?" and decoding it multiple times, we can figure out which bits in the decoded data are known, and which bits in the decoded data come from the "?"s. This is useful when patching important "?"s so that the length information makes sense.
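The known/unknown bit bookkeeping can also be done deterministically, since every base64 character carries exactly 6 bits. The following sketch is an alternative to the random-sampling approach. The helper names are hypothetical, and padding characters and linebreaks are simply ignored.

```python
import string

B64 = string.ascii_uppercase + string.ascii_lowercase + string.digits + "+/"


def known_bit_mask(text: str) -> str:
    """Maps base64 text to a bit string, using '?' for bits hidden by a '?' character."""
    out = []
    for ch in text:
        if ch in B64:
            out.append(f"{B64.index(ch):06b}")
        elif ch == "?":
            out.append("?" * 6)
        # anything else (linebreaks, '=' padding) is skipped in this sketch
    return "".join(out)


def byte_view(mask: str) -> list[str]:
    """Splits the bit string into per-byte groups of 8 (dropping a trailing partial group)."""
    return [mask[i:i + 8] for i in range(0, len(mask) - len(mask) % 8, 8)]
```

For example, `byte_view(known_bit_mask("QU?D"))` shows that the first decoded byte is fully known while the second and third are only partially known.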

The main part is to patch everything so that the length information makes sense. Since the values are $n, e, d, p, q, d_p, d_q, u$ in order (preceded by a version field), just try every patch that makes sense and does not affect the bits that are actually known. Decoding an actual 2048 bit RSA PEM helps.

 

In the end, we'll have the full $n, e$ and some bits of $p, q, d_p, d_q, u$. However, $u$ is not needed here. 

The rest is relatively well-known - you set up an equation $$pq = n, \quad ed_p = k_p(p-1) + 1, \quad ed_q = k_q(q-1) + 1$$ and solve this equation modulo powers of 2 iteratively. To do so, you need to fix $k_p, k_q$ - it turns out that there are $\mathcal{O}(e)$ possible pairs of $(k_p, k_q)$, so you can try out all candidates. For example, see [HS09]. There is also my challenge "RSA Hex Permutation". 
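One way to see why there are only $\mathcal{O}(e)$ pairs: reducing the three equations modulo $e$ gives $p \equiv (k_p - 1)/k_p$ and $q \equiv (k_q - 1)/k_q$, and plugging these into $pq = n$ yields $k_q \equiv (1 - k_p) \cdot (k_p n - k_p + 1)^{-1} \pmod{e}$, so $k_q$ is determined by $k_p$. The sketch below checks this on toy primes (not the challenge key). `predict_k_q` is a made-up helper name.

```python
def predict_k_q(k_p: int, n: int, e: int) -> int:
    """k_q as a function of k_p, derived from pq = n reduced modulo e."""
    return ((1 - k_p) * pow(k_p * n - k_p + 1, -1, e)) % e


e = 65537
p, q = 1000003, 1000000007            # toy primes with gcd(e, p-1) = gcd(e, q-1) = 1
n = p * q

d_p = pow(e, -1, p - 1)
d_q = pow(e, -1, q - 1)
k_p = (e * d_p - 1) // (p - 1)        # e*d_p = k_p*(p-1) + 1
k_q = (e * d_q - 1) // (q - 1)        # e*d_q = k_q*(q-1) + 1
```

With this relation, looping over $k_p \in [1, e)$ covers every candidate pair.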

 

from sage.all import *
from Crypto.Util.number import inverse, long_to_bytes, bytes_to_long 
from hashlib import sha256 
from base64 import b64decode
import string 
import random as rand 
 
raw_pem = open("meme.pem", "rb").read()
 
 
ST = b"-----BEGIN RSA PRIVATE KEY-----\n"
EN = b"\n-----END RSA PRIVATE KEY-----"
raw_pem = raw_pem[len(ST): -len(EN)]
print(raw_pem)
 
meme = (string.ascii_uppercase + string.ascii_lowercase + string.digits + "+/=\n").encode()
 
print(len(raw_pem))
 
 
def gen_copy(lmao):
    ret = b""
    for l in lmao:
        ret += bytes([l])
    for i in range(len(ret)):
        if ret[i] not in meme:
            sel = rand.randint(0, 63)
            ret = ret[:i] + bytes([meme[sel]]) + ret[i + 1: ]
    return ret
 
recovered = b64decode(gen_copy(raw_pem))
 
print(len(recovered))
 
EQ = [1] * (len(recovered) * 8)
 
for trial in range(50):
    raw_pem_new = b64decode(gen_copy(raw_pem))
    for j in range(len(recovered)):
        for k in range(8):
            if ((recovered[j] >> (7 - k)) & 1) != ((raw_pem_new[j] >> (7 - k)) & 1):
                EQ[8 * j + k] = 0
 
def get_eq(l, r):
    print("EQ", l, r, EQ[8 * l : 8 * r])
 
def patch(org, l, r, patch):
    return org[:l] + bytes(patch) + org[r:]
 
# patch
 
# for e
recovered = patch(recovered, 272, 273, [1])
# for d length
recovered = patch(recovered, 275, 277, [1, 1]) # either [1, 0] or [1, 1]
# for p length
recovered = patch(recovered, 535, 537, [129, 129])
# for d mod p - 1
recovered = patch(recovered, 799, 800, [129])
# for d mod q - 1
recovered = patch(recovered, 930, 932, [129, 128])
# for u
recovered = patch(recovered, 1061, 1062, [129])
 
cur = 0
vals = []
for i in range(10):
    print("START", i, cur)
    print("octet", cur, recovered[cur])
    cur += 1
    l = recovered[cur]
    print("[+] length heat check", l, cur)
    get_eq(cur, cur + 1)
    cur += 1
    if l > 127:
        tt = l & 0x7F
        real_l = bytes_to_long(recovered[cur:cur+tt])
        print("[+] real length")
        get_eq(cur, cur + tt)
        print(recovered[cur:cur+tt])
        print(real_l, "at range", cur, cur+tt)
        cur += tt
    else:
        real_l = l
    if i == 0:
        continue 
    print("[+] data", recovered[cur:cur+real_l])
    get_eq(cur, cur + real_l)
    val = bytes_to_long(recovered[cur:cur+real_l])
    print("[+] appended", val)
    vals.append(val)
    cur += real_l 
 
print(vals)
 

 

n = data[1]
e = data[2]
d_val = data[3]
p_val = data[4]
q_val = data[5]
d_p_val = data[6]
d_q_val = data[7]
 
k_p, k_q = 0, 0
 
import sys
sys.setrecursionlimit(10 ** 6)
 
 
def work(ps, qs, dps, dqs, cur):
    if cur == 1028:
        return
    
    nxtps = []
    nxtqs = []
    nxtdps = []
    nxtdqs = []
 
    if cur > 950 and len(ps) == 1:
        # print(ps, cur)
        if n % ps[0] == 0:
            print("WOW!!!")
    if len(ps) == 0:
        return
 
    for p, q, dp, dq in zip(ps, qs, dps, dqs):
        if cur >= 1000:
            if n % p == 0:
                print(p)
                print(n // p)
 
        if cur == 1028:
            continue
        
        for pi in range(2):
            if p_conf[-1-cur] == 1 and ((p_val >> cur) & 1) != pi:
                continue
            new_p = p + (pi << cur)
            for qi in range(2):
                if q_conf[-1-cur] == 1 and ((q_val >> cur) & 1) != qi:
                    continue
                new_q = q + (qi << cur)
                for dpi in range(2):
                    if d_p_conf[- 1 - cur] == 1 and ((d_p_val >> cur) & 1) != dpi:
                        continue 
                    new_dp = dp + (dpi << cur)
                    for dqi in range(2):
                        if d_q_conf[-1 - cur] == 1 and ((d_q_val >> cur) & 1) != dqi:
                            continue 
                        new_dq = dq + (dqi << cur)
                        A = abs(n - new_p * new_q)
                        B = abs(e * new_dp - k_p * (new_p - 1) - 1)
                        C = abs(e * new_dq - k_q * (new_q - 1) - 1)
                        if ((A >> cur) & 1) != 0:
                            continue
                        if ((B >> cur) & 1) != 0:
                            continue 
                        if ((C >> cur) & 1) != 0:
                            continue
                        nxtps.append(new_p)
                        nxtqs.append(new_q)
                        nxtdps.append(new_dp)
                        nxtdqs.append(new_dq)
 
    work(nxtps, nxtqs, nxtdps, nxtdqs, cur + 1)
 
from tqdm import tqdm 
 
for idx in tqdm(range(1, e)):
    if (idx * (n - 1) + 1) % e == 0:
        continue 
    k_p = idx 
    k_q = ((1 - k_p) * inverse(k_p * n - k_p + 1, e)) % e
    work([0], [0], [0], [0], 0)

 

 

SusCipher

#!/usr/bin/env python3
import hashlib
import os
import signal
 
 
class SusCipher:
    S = [
        43,  8, 57, 53, 48, 39, 15, 61,
         7, 44, 33,  9, 19, 41,  3, 14,
        42, 51,  6,  2, 49, 28, 55, 31,
         0,  4, 30,  1, 59, 50, 35, 47,
        25, 16, 37, 27, 10, 54, 26, 58,
        62, 13, 18, 22, 21, 24, 12, 20,
        29, 38, 23, 32, 60, 34,  5, 11,
        45, 63, 40, 46, 52, 36, 17, 56
    ]
 
    P = [
        21,  8, 23,  6,  7, 15,
        22, 13, 19, 16, 25, 28,
        31, 32, 34, 36,  3, 39,
        29, 26, 24,  1, 43, 35,
        45, 12, 47, 17, 14, 11,
        27, 37, 41, 38, 40, 20,
         2,  0,  5,  4, 42, 18,
        44, 30, 46, 33,  9, 10
    ]
 
    ROUND = 3
    BLOCK_NUM = 8
    MASK = (1 << (6 * BLOCK_NUM)) - 1
 
    @classmethod
    def _divide(cls, v: int) -> list[int]:
        l: list[int] = []
        for _ in range(cls.BLOCK_NUM):
            l.append(v & 0b111111)
            v >>= 6
        return l[::-1]
 
    @staticmethod
    def _combine(block: list[int]) -> int:
        res = 0
        for v in block:
            res <<= 6
            res |= v
        return res
 
    @classmethod
    def _sub(cls, block: list[int]) -> list[int]:
        return [cls.S[v] for v in block]
 
    @classmethod
    def _perm(cls, block: list[int]) -> list[int]:
        bits = ""
        for b in block:
            bits += f"{b:06b}"
 
        buf = ["_" for _ in range(6 * cls.BLOCK_NUM)]
        for i in range(6 * cls.BLOCK_NUM):
            buf[cls.P[i]] = bits[i]
 
        permd = "".join(buf)
        return [int(permd[i : i + 6], 2) for i in range(0, 6 * cls.BLOCK_NUM, 6)]
 
    @staticmethod
    def _xor(a: list[int], b: list[int]) -> list[int]:
        return [x ^ y for x, y in zip(a, b)]
 
    def __init__(self, key: int):
        assert 0 <= key <= self.MASK
 
        keys = [key]
        for _ in range(self.ROUND):
            v = hashlib.sha256(str(keys[-1]).encode()).digest()
            v = int.from_bytes(v, "big") & self.MASK
            keys.append(v)
 
        self.subkeys = [self._divide(k) for k in keys]
 
    def encrypt(self, inp: int) -> int:
        block = self._divide(inp)
 
        block = self._xor(block, self.subkeys[0])
        for r in range(self.ROUND):
            block = self._sub(block)
            block = self._perm(block)
            block = self._xor(block, self.subkeys[r + 1])
 
        return self._combine(block)
 
    # TODO: Implement decryption
    def decrypt(self, inp: int) -> int:
        raise NotImplementedError()
 
 
def handler(_signum, _frame):
    print("Time out!")
    exit(0)
 
 
def main():
    signal.signal(signal.SIGALRM, handler)
    signal.alarm(300)
    key = int.from_bytes(os.urandom(6), "big")
 
    cipher = SusCipher(key)
 
    while True:
        inp = input("> ")
 
        try:
            l = [int(v.strip()) for v in inp.split(",")]
        except ValueError:
            print("Wrong input!")
            exit(0)
 
        if len(l) > 0x100:
            print("Long input!")
            exit(0)
 
        if len(l) == 1 and l[0] == key:
            with open('flag', 'r') as f:
                print(f.read())
 
        print(", ".join(str(cipher.encrypt(v)) for v in l))
 
 
if __name__ == "__main__":
    main()
 

 

This is a 3-round block cipher, and a hint is given - use differential cryptanalysis. 

 

Let's find some easy differentials. We collect the pairs $((i \oplus j)[u], (S_i \oplus S_j)[v])$ over all $i, j$ and see how they behave - and it turns out that the bias is noticeable. We can keep track of an array $pos$, where $pos_i$ denotes the bit location that is the most correlated with the original $i$th bit.

 

In the current state, there are 3 S-box applications, so the bias will decrease over those 3 S-boxes. However, if we know a key chunk, for example, the first 6 bits, then the first key addition and the S-box application can be computed directly, so in reality we are only going through 2 S-boxes. Therefore, the correlation between the state just after the first S-box and the encrypted state will be much greater.

We can now brute force all 64 possibilities for all 8 key chunks. 20K random encryptions are enough to reliably find the key.

 

from pwn import * 
import os 
 
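S = [
        43,  8, 57, 53, 48, 39, 15, 61,
         7, 44, 33,  9, 19, 41,  3, 14,
        42, 51,  6,  2, 49, 28, 55, 31,
         0,  4, 30,  1, 59, 50, 35, 47,
        25, 16, 37, 27, 10, 54, 26, 58,
        62, 13, 18, 22, 21, 24, 12, 20,
        29, 38, 23, 32, 60, 34,  5, 11,
        45, 63, 40, 46, 52, 36, 17, 56
    ]
 
P = [
        21,  8, 23,  6,  7, 15,
        22, 13, 19, 16, 25, 28,
        31, 32, 34, 36,  3, 39,
        29, 26, 24,  1, 43, 35,
        45, 12, 47, 17, 14, 11,
        27, 37, 41, 38, 40, 20,
         2,  0,  5,  4, 42, 18,
        44, 30, 46, 33,  9, 10
    ]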
 
def get_bit(res, w):
    return (res >> (5 - w)) & 1
 
 
bits = [[[0, 0, 0, 0] for _ in range(6)] for _ in range(6)]
 
for i in range(64):
    for j in range(64):
        # i ^ j => Si ^ Sj 
        for u in range(6):
            for v in range(6):
                t1 = get_bit(i ^ j, u)
                t2 = get_bit(S[i] ^ S[j], v)
                bits[u][v][2 * t1 + t2] += 1
 
 
mx = 0
for i in range(6):
    print("[+]", i)
    mx_j = 0
    for j in range(6):
        mx_j = max(mx_j, bits[i][j][0])
    print(mx_j)
    for j in range(6):
        if bits[i][j][0] == mx_j:
            print(j)
 
 
sub_loc = [5, 0, 1, 1, 0, 5]
 
 
def track_sub(pos):
    ret = [0] * 48
    for i in range(48):
        loc = pos[i] // 6
        md = pos[i] % 6
        ret[i] = 6 * loc + sub_loc[md]
    return ret
 
def track_perm(pos):
    ret = [0] * 48
    for i in range(48):
        ret[i] = P[pos[i]]
    return ret
 
full_enc = [i for i in range(48)]
for i in range(3):
    full_enc = track_sub(full_enc)
    full_enc = track_perm(full_enc)
 
part_enc = [i for i in range(48)]
part_enc = track_perm(part_enc)
for i in range(2):
    part_enc = track_sub(part_enc)
    part_enc = track_perm(part_enc)
 
plaintexts = [int.from_bytes(os.urandom(6), "big") for _ in range(256 * 80)]
 
l = []
 
for i in range(80):
    st = ""
    for j in range(256):
        st += str(plaintexts[256 * i + j])
        if j != 255:
            st += ","
    l.append(st)
 
 
conn = remote("suscipher.chal.ctf.acsc.asia", 13579)
 
conn.sendlines(l)
 
results = conn.recvlines(80)
 
enc = []
 
for i in range(80):
    encs = results[i][2:].split(b",")
    assert len(encs) == 256
    for j in range(256):
        enc.append(int(encs[j]))
 
 
key = 0
 
for i in range(8):
    res = [[0 for _ in range(6)] for _ in range(64)]
    fin = [0] * 64
 
    dat_eq = 0
    tot = 0
    
    for idx in range(len(enc)):
        plaintext = plaintexts[idx]
        ciphertext = enc[idx]
 
        ptxt_block = (plaintext >> (6 * (7 - i))) & 63
        for loc in range(6):
            for k in range(64):
                bit_org = (S[k ^ ptxt_block] >> (5 - loc)) & 1
                bit_res = (ciphertext >> (47 - part_enc[6 * i + loc])) & 1
                if bit_org == bit_res:
                    res[k][loc] += 1
 
    for idx in range(64):
        for u in range(6):
            fin[idx] += abs(res[idx][u] - len(enc) // 2)
 
    v = 0
    for idx in range(64):
        v = max(v, fin[idx])
    
    ans = 0
    for idx in range(64):
        if v == fin[idx]:
            ans = idx
 
    key += (ans << (6 * (7 - i)))
 
conn.sendline(str(key).encode())
print(conn.recvline())
 
 
 
 

 

 

Serverless

Basically, the encryption system works as follows. 

  • Select one prime $p$ from a fixed array
  • Select one prime $q$ from a fixed array 
  • Select $0 \le s \le 4$ and choose $e = 2^{2^s} + 1$
  • Textbook RSA encrypt data, little-endian it, append some values ($s$ and the indexes for $p, q$)
  • XOR the password, reverse the data, then base64 encode it 

As we know the fixed array of primes and the password, the decryption is easy. 

 
