Finished 8th in ACSC Quals 2023. Solved all cryptography challenges and warmup rev. I had an alumni meetup with NEXON Youth Programming Challenge award winners, so I went home at like 9PM with a beer and sangria inside me.

(Here, sangria is not the recent folding scheme for PLONK from Geometry Research)

 

Kinda gave up on the whole ACSC thing until it was around midnight, then I saw that I could make it if I solved all cryptography challenges.

So I hurried to do so and finished around 5:30AM, then solved the warmup rev. Turns out that this was a good idea.

 

Merkle-Hellman

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
import random
import binascii
 
def egcd(a, b):
    """Extended Euclidean algorithm.

    Returns ``(g, x, y)`` such that ``g == gcd(a, b)`` and ``a * x + b * y == g``.
    """
    if a == 0:
        return (b, 0, 1)
    g, y, x = egcd(b % a, a)
    return (g, x - (b // a) * y, y)
 
def modinv(a, m):
    """Return the multiplicative inverse of ``a`` modulo ``m``.

    Raises a plain ``Exception`` when ``gcd(a, m) != 1``.
    """
    g, x, _ = egcd(a, m)
    if g == 1:
        return x % m
    raise Exception('modular inverse does not exist')
 
def gcd(a, b):
    """Return the greatest common divisor of ``a`` and ``b`` (Euclid's algorithm)."""
    while a != 0:
        a, b = b % a, a
    return b
 
with open("flag.txt", "rb") as fh:
    flag = fh.read()

# Generate superincreasing sequence: each new element exceeds the sum of all previous ones.
w = [random.randint(1, 256)]
s = w[0]
for i in range(6):
    num = random.randint(s + 1, s + 256)
    w.append(num)
    s += num

# Generate private key: modulus q > sum(w) and a multiplier r coprime to q.
total = sum(w)
q = random.randint(total + 1, total + 256)
r = 0
while gcd(r, q) != 1:
    r = random.randint(100, q)

# Calculate public key: b_i = w_i * r mod q.
b = [(i * r) % q for i in w]

# Encrypting: bit (64 >> i) of each byte selects b[i]; only the low 7 bits are used.
c = []
for f in flag:
    s = 0
    for i in range(7):
        if f & (64 >> i):
            s += b[i]
    c.append(s)

print(f"Public Key = {b}")
print(f"Private Key = {w,q}")
print(f"Ciphertext = {c}")

# Output:
# Public Key = [7352, 2356, 7579, 19235, 1944, 14029, 1084]
# Private Key = ([184, 332, 713, 1255, 2688, 5243, 10448], 20910)
# Ciphertext = [8436, 22465, 30044, 22465, 51635, 10380, 11879, 50551, 35250, 51223, 14931, 25048, 7352, 50551, 37606, 39550]
cs

 

This is a knapsack-based cryptosystem, and we know practically everything here. Just decrypt it.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from sage.all import *
from Crypto.Util.number import inverse, long_to_bytes, bytes_to_long 
 
pk = [7352, 2356, 7579, 19235, 1944, 14029, 1084]
sk = [184, 332, 713, 1255, 2688, 5243, 10448]
q = 20910
ctxt = [8436, 22465, 30044, 22465, 51635, 10380, 11879, 50551, 35250, 51223, 14931, 25048, 7352, 50551, 37606, 39550]

# pk[i] = sk[i] * r (mod q), hence sk[2] * pk[2]^-1 = r^-1 (mod q).
df = sk[2] * pow(pk[2], -1, q)

res = b""

for c in ctxt:
    # Undo the multiplier: c * r^-1 mod q is a subset sum of the superincreasing sk
    # (exact, because sum(sk) < q).
    c = (c * df) % q
    f = 0
    # Greedy from the largest element; sk[i] corresponds to bit (64 >> i) == 1 << (6 - i).
    for i in range(6, -1, -1):
        if c >= sk[i]:
            f += (1 << (6 - i))
            c -= sk[i]
    if c != 0:
        raise ValueError("ciphertext is not a subset sum of the private key")
    res += bytes([f])

print(res)
 
cs

 

 

Check Number 63

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from Crypto.Util.number import *
import gmpy2
from flag import *
 
# n, e, p, q are provided by the (hidden) flag module imported above.
with open("output.txt", "w") as f:
    f.write(f"n = {n}\n")

    phi = (p - 1) * (q - 1)
    while e < 66173:
        d = inverse(e, phi)
        # e*d = check_number * phi + 1, so check_number is the exact quotient.
        check_number = (e * d - 1) // phi
        f.write(f"{e}:{check_number}\n")
        assert (e * d - 1) % phi == 0
        e = gmpy2.next_prime(e)
 
 
cs

 

We have $n$ alongside 63 pairs of $(e, k)$ where $$ed = k \phi(n) + 1$$ The goal is to factor $n$. First, the central idea is to note that $$k \phi(n) + 1 \equiv 0 \pmod{e}$$ so we recover $$\phi(n) \equiv - k^{-1} \pmod{e}$$ Combined with CRT, this gives us $$\phi(n) \pmod{ \prod e}$$ which is around $63 \times 16 = 1008$ bits of information. Meanwhile, since $$\phi(n) = n + 1 - (p + q)$$ and $p + q$ is about 1025 bits, $\phi(n)$ lies in a range of about $2^{1025}$ values. Knowing it modulo a roughly 1008-bit number leaves only about $2^{17}$ candidates. We can easily brute-force all of them, then recover $p, q$ from $n, \phi(n)$. 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from sage.all import *
from Crypto.Util.number import inverse, long_to_bytes, bytes_to_long 
 
# ed = k phi(n) + 1
# k phi(n) + 1 == 0 mod e
# phi(n) == -k^-1 mod e
 
def iroot(v, r):
    """Return the exact integer r-th root of ``v``.

    Sage's ``nth_root`` with ``truncate_mode=False`` raises ValueError when ``v``
    is not a perfect r-th power (also for negative ``v`` with even ``r``).
    """
    return Integer(v).nth_root(r, truncate_mode=False)


# output.txt (as written by the challenge) starts with "n = <value>" and is
# followed by "e:k" lines where e*d = k*phi(n) + 1.
with open("output.txt", "r") as fh:
    tt = [line.strip() for line in fh if line.strip()]

n = None
pairs = []
for l in tt:
    if l.startswith("n"):
        n = int(l.split("=", 1)[1])
        continue
    wow = l.split(":")
    pairs.append((int(wow[0]), int(wow[1])))
if n is None:
    raise ValueError("output.txt must contain the line 'n = <value>'")

# k*phi(n) + 1 == 0 (mod e)  =>  phi(n) == -k^-1 (mod e)
vals = []
mods = []
for e, res in pairs:
    vals.append(int(e - inverse(res, e)))
    mods.append(e)

cc = int(prod(mods))

phi_cc = int(CRT(vals, mods))

# Sanity check: phi_cc satisfies every congruence.
for e, res in pairs:
    assert (phi_cc * res + 1) % e == 0

from tqdm import tqdm
from hashlib import sha512

# p + q = n + 1 - phi(n) is known modulo cc; walk its lifts tot_base + i*cc.
tot_base = int((n - phi_cc + 1) % cc)
for i in tqdm(range(1 << 20)):
    tot = tot_base + i * cc
    try:
        root = iroot(tot * tot - 4 * n, 2)
    except ValueError:
        continue  # discriminant is not a perfect square: wrong lift
    p = (tot - root) // 2
    q = tot - p
    if p >= q or p * q != n:
        continue
    print(p, q)
    flag = "ACSC{" + sha512(f"{p}{q}".encode()).hexdigest() + "}"
    print(flag)
    break
 
 
 
 
 
 
cs

 

 

Dual Signature Algorithm

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
from hashlib import sha256
from Crypto.Util.number import getPrime, isPrime, getRandomNBitInteger, inverse
 
 
# Secret flag from the environment, with a placeholder default for local testing.
flag = os.environ.get("FLAG", "neko{cat_are_the_most_powerful_beings_in_fact}")
 
 
def h(m: bytes) -> int:
    """Hash ``m`` with SHA-256 and interpret the digest as a big-endian integer."""
    digest = sha256(m).digest()
    return int.from_bytes(digest, "big")
 
 
def gen_prime():
    """Return ``(p, q)`` where q is a 520-bit prime and p = 2q + 1 is also prime (safe prime)."""
    while True:
        q = getPrime(520)
        p = 2 * q + 1
        if isPrime(p):
            return p, q
 
 
p1, q1 = gen_prime()
p2, q2 = gen_prime()

# Order the two groups so that q1 < q2.
if q1 > q2:
    (p1, q1), (p2, q2) = (p2, q2), (p1, q1)

# Secret x: the flag left-padded with random bytes to 63 bytes in total
# (os.urandom rejects a negative length, so this requires len(flag) <= 63).
x = int((os.urandom(512 // 8 - len(flag) - 1) + flag.encode()).hex(), 16)
# 4 = 2^2 is a quadratic residue, so it lies in the order-q subgroup of Z_p^*.
g = 4
y1 = pow(g, x, p1)
y2 = pow(g, x, p2)
 
 
def sign(m: bytes):
    """Sign ``m`` in both groups with one shared 512-bit nonce ``k``.

    Returns ``((r1, s1), (r2, s2))``. Because ``k`` is shared, it can be
    eliminated between the two signing equations.
    """
    z = h(m)
    k = getRandomNBitInteger(512)
    signatures = []
    for p, q in ((p1, q1), (p2, q2)):
        r = pow(g, k, p)
        s = inverse(k, q) * (z + r * x) % q
        signatures.append((r, s))
    return signatures[0], signatures[1]
 
 
def verify(m: bytes, sig1, sig2):
    """Return True iff both signatures on ``m`` verify in their respective groups."""
    z = h(m)
    (r1, s1), (r2, s2) = sig1, sig2

    w1, w2 = inverse(s1, q1), inverse(s2, q2)
    # Recompute g^k in each group from the signature; it must equal r.
    expected1 = pow(g, w1 * z, p1) * pow(y1, w1 * r1, p1) % p1
    expected2 = pow(g, w2 * z, p2) * pow(y2, w2 * r2, p2) % p2

    return (r1, r2) == (expected1, expected2)
 
 
# Sign one fixed message and publish everything except x, q1 and q2.
m = b"omochi mochimochi mochimochi omochi"
sig1, sig2 = sign(m)

print(f"g = {g}")
print(f"p1, p2 = {p1}, {p2}")
print(f"y1, y2 = {y1}, {y2}")

print(f"m = {m}")
print(f"r1, s1 = {sig1}")
print(f"r2, s2 = {sig2}")
 
cs

 

I overcomplicated this problem way too much - the easier solution is combining the two signature schemes via CRT then LLL.

I tried some straightforward lattices without the CRT idea, but it didn't give me the answer. Here's my solution.

 

Start with the equations $$ks_1 = z + r_1x + c_1q_1, \quad ks_2 = z + r_2x + c_2q_2$$ where $c_1, c_2$ each have absolute values of at most something like $2^{515}$. We'll cancel out the $k$ in the equations to get $$s_2(z + r_1x + c_1q_1) = s_1(z+ r_2x + c_2q_2)$$ or $$(s_2r_1 - s_1r_2)x = (s_1 - s_2)z + (s_1q_2) c_2 - (s_2q_1) c_1$$ This gives us a system - we want $$-2^{515} \le c_1, c_2 \le 2^{515}$$ as well as the linear equation $$(s_1q_2) c_2 \equiv (s_2q_1) c_1 + (s_2 - s_1)z \pmod{s_2r_1 - s_1r_2}$$ While there are some GCD issues like $\gcd(s_1q_2, s_2r_1 - s_1r_2) = 2 \neq 1$, in essence this is the same type of problem as $$S \le t \le E, \quad L \le (At + B) \bmod{M} \le R$$ (with $t$ playing the role of $c_1$), which is the exact task that the special case of my CVP repository solves. After getting $c_1, c_2$, recovering $x$ is just a matter of solving the linear equation above. 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
def ceil(n, m):  # returns ceil(n/m)
    """Return ceil(n / m) using exact integer arithmetic (valid for any sign of n and m != 0)."""
    return -(-n // m)
 
def is_inside(L, R, M, val):  # is L <= val <= R in mod M context?
    """Return True if ``val`` lies in the interval [L, R], read modulo M.

    When L > R the interval wraps around: it is treated as [L, R + M], and
    ``val`` is also tried shifted up by one period M.
    """
    if L > R:
        hi = R + M
        return L <= val <= hi or L <= val + M <= hi
    return L <= val <= R
 
def optf(A, M, L, R): # minimum nonnegative x s.t. L <= Ax mod M <= R
    """Return the smallest x >= 0 with L <= (A * x) mod M <= R.

    Euclid-like recursion. Each call, after the optional flip below, has
    2 * A <= M and recurses with modulus A, so the modulus at least halves
    per level.
    NOTE(review): appears to assume that a solution exists (callers check
    with ``sol_ex`` first) and that 0 <= L <= R < M -- confirm.
    """
    if L == 0:
        # x = 0 gives residue 0, which is inside [0, R].
        return 0
    if 2 * A > M:
        # (M - A) * x == -A * x (mod M), so L <= Ax <= R is equivalent to
        # M - R <= (M - A) * x <= M - L; this keeps A <= M / 2.
        L, R = R, L
        A, L, R = M - A, M - L, M - R
    # Smallest x with A * x >= L; if it does not overshoot R, no wrap is needed.
    cc_1 = ceil(L, A)
    if A * cc_1 <= R:
        return cc_1
    # No multiple of A lies in [L, R]: solve for the number of wraps cc_2
    # (a smaller instance modulo A), then take the first x after L + M * cc_2.
    cc_2 = optf(A - M % A, A, L % A, R % A)
    return ceil(L + M * cc_2, A)
 
# check if L <= Ax (mod M) <= R has a solution
def sol_ex(A, M, L, R):
    """Return True if some x satisfies L <= (A * x) mod M <= R.

    The reachable residues of A * x mod M are exactly the multiples of
    g = gcd(A, M). An empty range (L == 0, where x = 0 works) or a wrapped
    range (L > R, which contains 0) always has a solution.
    """
    from math import gcd

    if L == 0 or L > R:
        return True
    g = gcd(A, M)
    # [L, R] contains a multiple of g unless floor((L - 1) / g) == floor(R / g).
    return (L - 1) // g != R // g
 
## find all solutions for L <= Ax mod M <= R, S <= x <= E:
def solve(A, M, L, R, S, E):
    """Return every x in [S, E] with L <= (A * x) mod M <= R, in increasing order.

    Each iteration jumps directly to the next solution via ``optf``, so the
    work is proportional to the number of solutions, not to E - S.
    """
    # this is for estimate only : if very large, might be a bad idea to run this
    print("Expected Number of Solutions : ", ((E - S + 1) * (R - L + 1)) // M + 1)
    if not sol_ex(A, M, L, R):
        return []
    cur = S - 1
    ans = []
    while cur <= E:
        # Look for the smallest y >= 0 with L <= A * (cur + 1 + y) mod M <= R,
        # i.e. NL <= A * y mod M <= NR.
        NL = (L - A * (cur + 1)) % M
        NR = (R - A * (cur + 1)) % M
        if NL > NR:
            # The shifted window wraps past 0, so y = 0 already works.
            cur += 1
        else:
            cur += 1 + optf(A, M, NL, NR)
        if cur <= E:
            ans.append(cur)
            # remove assert for performance if needed
            assert is_inside(L, R, M, (A * cur) % M)
    print("Actual Number of Solutions : ", len(ans))
    return ans
 
q1 = (p1 - 1) // 2
q2 = (p2 - 1) // 2

# Eliminating k:  (s2*r1 - s1*r2) * x = (s1 - s2) * z + (s1*q2) * c_2 - (s2*q1) * c_1,
# so modulo det_v:  md_2 * c_2 == md_1 * c_1 + z.
det_v = abs(s2 * r1 - s1 * r2)

md_2 = s1 * q2
md_1 = s2 * q1

z = (h(m) * (s2 - s1)) % det_v

# This instance has det_v, md_2, z even and md_1 odd, which forces c_1 to be even.
# Fail loudly if that assumption does not hold.
assert det_v % 2 == 0
assert z % 2 == 0
assert md_1 % 2 == 1
assert md_2 % 2 == 0

# Divide the congruence by 2 and substitute c_1 = 2t:
#   (md_2/2) * c_2 == md_1 * t + z/2  (mod det_v/2)  =>  c_2 == A*t + B.
md_2 //= 2
det_v //= 2
z //= 2

A = (md_1 * inverse(md_2, det_v)) % det_v
B = (z * inverse(md_2, det_v)) % det_v

BOUND = 1 << 515

print(det_v)

# Find t in [-BOUND, BOUND] with A*t + B (mod det_v) in [-BOUND, BOUND].
pepega = solve(A, det_v, det_v - BOUND - B, det_v + BOUND - B, -BOUND, BOUND)

print(len(pepega))
if not pepega:
    raise ValueError("no candidate c_1 found")

c_1 = int(pepega[0] * 2)
c_2 = int((md_1 * (c_1 // 2) + z) * inverse(md_2, det_v) % det_v)
# Map residues back to the signed range.
if c_1 > BOUND:
    c_1 -= det_v
if c_2 > BOUND:
    c_2 -= det_v

print(abs(c_1) < BOUND)
print(abs(c_2) < BOUND)

# Solve the linear equation above for x.
LHS = s2 * h(m) - s1 * h(m) + s2 * c_1 * q1 - s1 * c_2 * q2
RHS = s1 * r2 - s2 * r1
x = LHS // RHS
print(LHS % RHS)  # 0 when (c_1, c_2) is the correct pair
print(long_to_bytes(x))
 
cs

 

 

Corrupted

We are given a broken PEM file that looks like the following.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAn+8Rj11c2JOgyf6s1Hiiwt553hw9+oGcd1EGo8H5tJOEiUnP
NixaIGMK1O7CU7+IEe43PJcGPPkCti2kz5qAXAyXXBMAlHF46spmQaQFpVRRVMZD
1yInh0QXEjgBBFZKaH3VLh9FpCKYpfqij+OlphoSHlfc7l2Wfct40TDFg13WdpVB
BseCEmaY/b+kxwdfVe7Dzt8kd2ASPuNbOqKvv8ijTgiqpsX5uinjvr/3/srINm8X
xpANqO/eSXP8kO4abOJtyfg2bWvO9QvQRaUIjnYioAkyiqcttbzGIekCfktlA+Rn
JLL19tEG43hubOZAwqGDxvXfKEKx9E2Yx4Da/wIDAQA?AoI?????8S??Om/???xN
3c??0?/G?OO?aQWQB??ECCi??KD?w??2mFc??pTM?r?rX??X+XFW??Rtw?o?d????ZQ?yp?mczG?q2?0O???1o3?Jt?8?+00s?SY+??MG??7d??7k??o?????ci?K??????wK??Y??gqV????9????YA?Hh5T????ICP+?3HTU?l???m0y?6??2???b2x???????+7??T????????n?7????b?P??iL?/???tq???5jLuy??lX?d?ZEO?7???ld???g
?r?rK??IYA???0???zYCIZt2S???cP??W????f???l5?3c+??UkJr4E?QH??PiiD
WLB???f5A?G?A???????????u???3?K???????I???S?????????J?p?3?N?W???
????r???????8???o???m?????8?s???1?4?l?T?3?j?y?6?F?c?g?3?A?8?S?1?
X?o?D?C?+?7?F?V?U?1?f?K?a?F?7?S?b?V?/?v?5?1?V?A?5?G?y?X?AoGB?L?i
?2?C?t?W?s?Z?h?L?t?3?r?d?M?s?U?E?L?P?n?2?U?G?M?g?D?u?E?s?a?h?K?m
?9?/?n?o?J?8?e?9?9?k?N?2?l?T?8?k?b?e?j?n?Q?u?z?z?e?A?S?6?0?w?5?0
?B?V?i?s?R?W?6?Y?6?u?l?s?G?c?Q?2?Q?w?U?l??GA??V?f???kVYfl???WyY?
3J?2fF?h/???UqfpeO???o?k?9kF??a8L?V?w??????J??9?iP????D???JSx??g??IUC0??t7???I??c??????eh/No?????y8???0?E+??1?JC?Oj??HFy??2T?1nV??HH?+???+??s?L?o??K?zc?????BhB2A?????E??b???e?f??KruaZ??u?tp?Tq?c?t?????iQ1qS??h??m?S?/????FDu3i?p???S??Q?o??0s?e0?n?Hv??C?CnM?/Dw
m9?????uC?Ktm????D?e????h7?A??V??O??5/XsY??Y?A???????q?y?gk?Pbq?
????MQK?gQ??SQ?????ERjLp?N??A??P?So?TPE??WWG???lK?Q????o?aztnUT?
eKe4+h0?VkuB?b?v?7ge?nK1??Jy7?y??9??????BP??gG?kKK?y?Z???yES4i??
?Uhc?p????c4ln?m?r???P??C?8?X?d??TP??k??B?dwjN7??ui?K????????-?N? ?S? ?RI?A?? KE?-???-
cs

 

We need to recover the full PEM key. The solution is really hands-on, and it needs some grinding.

The PEM decoding algorithm is in pycryptodome - basically, it's just a DER decoding. So how does DER decoding work?

 

By following the DER implementation in pycryptodome alongside with some debugging, it's basically as follows.

  • 1 byte is consumed as the identifier octet (the tag, e.g. 0x30 for SEQUENCE and 0x02 for INTEGER). It is not needed for recovering the values here. 
  • Then, 1 byte is consumed as the length $l$. If $l \le 127$, then this is the final length.
  • If $l \ge 128$, then the next $l \pmod{128}$ bytes in big endian represent the final length.
  • The "final length" bytes worth of data, in big endian, is the pushed data. 

However, the very first "final length" is actually the full length, so this one should be skipped. 

 

Also, we quickly note that 

  • By comparing lengths, it can be seen that this file is based on 2048 bit RSA. 
  • In this case, the PEM file has a linebreak every 64 characters. Based on this, we can restore some "?" characters back to linebreaks.
  • The first step is to base64 decode the lines between the first and the last ones. By randomly selecting one of 64 choices for each "?" and decoding it multiple times, we can figure out which bits in the decoded data are known, and which bits in the decoded data come from the "?"s. This is useful when patching important "?"s so that the length information makes sense.

The main part is to patch everything so that the length information makes sense. Since the stored values are $n, e, d, p, q, d_p, d_q, u$ in order, just try every patch that makes sense and does not affect the bits that are actually known. Decoding an actual 2048 bit RSA PEM helps.

 

In the end, we'll have the full $n, e$ and some bits of $p, q, d_p, d_q, u$. However, $u$ is not needed here. 

The rest is relatively well-known - you set up an equation $$pq = n, \quad ed_p = k_p(p-1) + 1, \quad ed_q = k_q(q-1) + 1$$ and solve this equation modulo powers of 2 iteratively. To do so, you need to fix $k_p, k_q$ - it turns out that there are $\mathcal{O}(e)$ possible pairs of $(k_p, k_q)$, so you can try out all candidates. For example, see [HS09]. There is also my challenge "RSA Hex Permutation". 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from sage.all import *
from Crypto.Util.number import inverse, long_to_bytes, bytes_to_long 
from hashlib import sha256 
from base64 import b64decode
import string 
import random as rand 
 
with open("meme.pem", "rb") as fh:
    raw_pem = fh.read()


ST = b"-----BEGIN RSA PRIVATE KEY-----\n"
EN = b"\n-----END RSA PRIVATE KEY-----"
# Drop the armor lines by length (they are corrupted too), keeping the base64 body.
raw_pem = raw_pem[len(ST): -len(EN)]
print(raw_pem)

# Base64 alphabet (indices 0..63), then the padding character and newline.
meme = (string.ascii_uppercase + string.ascii_lowercase + string.digits + "+/=\n").encode()

print(len(raw_pem))
 
 
def gen_copy(lmao):
    """Return a copy of ``lmao`` with every byte not in ``meme`` replaced randomly.

    Replacements are drawn from ``meme[0:64]`` (the 64 base64 data characters),
    so '=' and newline are never introduced.
    """
    ret = bytearray(lmao)
    for i, ch in enumerate(ret):
        if ch not in meme:
            ret[i] = meme[rand.randint(0, 63)]
    return bytes(ret)
 
recovered = b64decode(gen_copy(raw_pem))

print(len(recovered))

# EQ[8*j + k] == 1 iff bit k (MSB first) of decoded byte j never changed across
# random fillings of the '?' characters, i.e. it is determined by known characters.
EQ = [1] * (len(recovered) * 8)

for trial in range(50):
    raw_pem_new = b64decode(gen_copy(raw_pem))
    for j in range(len(recovered)):
        diff = recovered[j] ^ raw_pem_new[j]
        for k in range(8):
            if (diff >> (7 - k)) & 1:
                EQ[8 * j + k] = 0
 
def get_eq(l, r):
    """Print the known-bit mask from ``EQ`` for decoded bytes l .. r-1."""
    mask = EQ[l * 8 : r * 8]
    print("EQ", l, r, mask)
 
def patch(org, l, r, patch):
    """Return ``org`` with the slice [l, r) replaced by ``bytes(patch)``."""
    replacement = bytes(patch)
    head, tail = org[:l], org[r:]
    return head + replacement + tail
 
# patch: overwrite corrupted bytes so that every DER length field parses.

# for e
recovered = patch(recovered, 272, 273, [1])
# for d length
recovered = patch(recovered, 275, 277, [1, 1])  # either [1, 0] or [1, 1]
# for p length
recovered = patch(recovered, 535, 537, [129, 129])
# for d mod p - 1
recovered = patch(recovered, 799, 800, [129])
# for d mod q - 1
recovered = patch(recovered, 930, 932, [129, 128])
# for u
recovered = patch(recovered, 1061, 1062, [129])

# Walk the DER: outer SEQUENCE, then version, n, e, d, p, q, dp, dq, u.
cur = 0
vals = []
for i in range(10):
    print("START", i, cur)
    print("octet", cur, recovered[cur])  # identifier (tag) byte
    cur += 1
    l = recovered[cur]
    print("[+] length heat check", l, cur)
    get_eq(cur, cur + 1)
    cur += 1
    if l > 127:
        # Long form: the low 7 bits give the number of big-endian length bytes.
        tt = l & 0x7F
        real_l = bytes_to_long(recovered[cur:cur + tt])
        print("[+] real length")
        get_eq(cur, cur + tt)
        print(recovered[cur:cur + tt])
        print(real_l, "at range", cur, cur + tt)
        cur += tt
    else:
        real_l = l
    if i == 0:
        # Outer SEQUENCE: its length covers the whole key, so step into it.
        continue
    print("[+] data", recovered[cur:cur + real_l])
    get_eq(cur, cur + real_l)
    val = bytes_to_long(recovered[cur:cur + real_l])
    print("[+] appended", val)
    vals.append(val)
    cur += real_l

print(vals)
 
cs

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# data presumably holds the integers parsed from the patched DER
# ([version, n, e, d, p, q, dp, dq, u]) -- TODO confirm against the parser's output.
n = data[1]
e = data[2]
d_val = data[3]
p_val = data[4]
q_val = data[5]
d_p_val = data[6]
d_q_val = data[7]

# Set per candidate by the k_p search loop below.
k_p, k_q = 0, 0

import sys
sys.setrecursionlimit(10 ** 6)
 
 
def work(ps, qs, dps, dqs, cur):
    """Bit-by-bit branch-and-prune search for p, q, d_p, d_q (LSB first).

    ps/qs/dps/dqs are aligned lists of candidates that are consistent modulo
    2**cur. For each candidate, bit ``cur`` of every value is guessed. A bit
    marked known in the matching ``*_conf`` list is forced to the recovered
    value. A guess survives only if bit ``cur`` is zero in each of
        n - p*q,   e*d_p - k_p*(p - 1) - 1,   e*d_q - k_q*(q - 1) - 1.
    Candidates that divide n are printed. Runs until bit 1028 or until no
    candidates remain.
    """
    from itertools import product

    def _choices(val, conf, bit):
        # Known bit: only the recovered value is allowed; otherwise try both.
        if conf[-1 - bit] == 1:
            return ((val >> bit) & 1,)
        return (0, 1)

    while cur < 1028 and ps:
        if cur > 950 and len(ps) == 1 and n % ps[0] == 0:
            print("WOW!!!")

        p_bits = _choices(p_val, p_conf, cur)
        q_bits = _choices(q_val, q_conf, cur)
        dp_bits = _choices(d_p_val, d_p_conf, cur)
        dq_bits = _choices(d_q_val, d_q_conf, cur)

        nxtps, nxtqs, nxtdps, nxtdqs = [], [], [], []
        for p, q, dp, dq in zip(ps, qs, dps, dqs):
            if cur >= 1000 and n % p == 0:
                print(p)
                print(n // p)

            for pi, qi, dpi, dqi in product(p_bits, q_bits, dp_bits, dq_bits):
                new_p = p + (pi << cur)
                new_q = q + (qi << cur)
                new_dp = dp + (dpi << cur)
                new_dq = dq + (dqi << cur)
                A = abs(n - new_p * new_q)
                B = abs(e * new_dp - k_p * (new_p - 1) - 1)
                C = abs(e * new_dq - k_q * (new_q - 1) - 1)
                # All three must vanish at bit `cur` (lower bits are already zero).
                if ((A | B | C) >> cur) & 1:
                    continue
                nxtps.append(new_p)
                nxtqs.append(new_q)
                nxtdps.append(new_dp)
                nxtdqs.append(new_dq)

        ps, qs, dps, dqs = nxtps, nxtqs, nxtdps, nxtdqs
        cur += 1
 
from tqdm import tqdm

# Try every k_p in [1, e). From (k_p - 1)(k_q - 1) == k_p * k_q * n (mod e),
# k_q == (1 - k_p) / (k_p * n - k_p + 1) (mod e); skip k_p where that is undefined.
for idx in tqdm(range(1, e)):
    if (idx * (n - 1) + 1) % e == 0:
        continue
    k_p = idx
    k_q = ((1 - k_p) * inverse(k_p * n - k_p + 1, e)) % e
    work([0], [0], [0], [0], 0)
cs

 

 

SusCipher

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
import hashlib
import os
import signal
 
 
class SusCipher:
    """Toy 3-round substitution-permutation network on 48-bit blocks.

    A block is split into eight 6-bit chunks (most significant first).
    Encryption XORs the master key, then runs ROUND rounds of: S-box on
    every chunk, a 48-bit bit permutation, and XOR with the next subkey.
    Each subkey is the SHA-256 of the previous key's decimal string,
    truncated to 48 bits.
    """

    # 6-bit S-box (a permutation of 0..63).
    S = [
        43,  8, 57, 53, 48, 39, 15, 61,
         7, 44, 33,  9, 19, 41,  3, 14,
        42, 51,  6,  2, 49, 28, 55, 31,
         0,  4, 30,  1, 59, 50, 35, 47,
        25, 16, 37, 27, 10, 54, 26, 58,
        62, 13, 18, 22, 21, 24, 12, 20,
        29, 38, 23, 32, 60, 34,  5, 11,
        45, 63, 40, 46, 52, 36, 17, 56
    ]

    # Bit permutation: input bit i (MSB first) moves to output position P[i].
    P = [
        21,  8, 23,  6,  7, 15,
        22, 13, 19, 16, 25, 28,
        31, 32, 34, 36,  3, 39,
        29, 26, 24,  1, 43, 35,
        45, 12, 47, 17, 14, 11,
        27, 37, 41, 38, 40, 20,
         2,  0,  5,  4, 42, 18,
        44, 30, 46, 33,  9, 10
    ]

    ROUND = 3
    BLOCK_NUM = 8
    MASK = (1 << (6 * BLOCK_NUM)) - 1

    @classmethod
    def _divide(cls, v: int) -> list[int]:
        """Split ``v`` into BLOCK_NUM 6-bit chunks, most significant first."""
        l: list[int] = []
        for _ in range(cls.BLOCK_NUM):
            l.append(v & 0b111111)
            v >>= 6
        return l[::-1]

    @staticmethod
    def _combine(block: list[int]) -> int:
        """Inverse of ``_divide``: concatenate 6-bit chunks into one integer."""
        res = 0
        for v in block:
            res <<= 6
            res |= v
        return res

    @classmethod
    def _sub(cls, block: list[int]) -> list[int]:
        """Apply the S-box to every chunk."""
        return [cls.S[v] for v in block]

    @classmethod
    def _inv_sub(cls, block: list[int]) -> list[int]:
        """Apply the inverse S-box to every chunk."""
        inv = [0] * len(cls.S)
        for i, s in enumerate(cls.S):
            inv[s] = i
        return [inv[v] for v in block]

    @classmethod
    def _perm(cls, block: list[int]) -> list[int]:
        """Move bit i of the 48-bit block (MSB first) to position P[i]."""
        bits = ""
        for b in block:
            bits += f"{b:06b}"

        buf = ["_" for _ in range(6 * cls.BLOCK_NUM)]
        for i in range(6 * cls.BLOCK_NUM):
            buf[cls.P[i]] = bits[i]

        permd = "".join(buf)
        return [int(permd[i : i + 6], 2) for i in range(0, 6 * cls.BLOCK_NUM, 6)]

    @classmethod
    def _inv_perm(cls, block: list[int]) -> list[int]:
        """Inverse of ``_perm``: original bit i is read back from position P[i]."""
        bits = "".join(f"{b:06b}" for b in block)
        orig = "".join(bits[cls.P[i]] for i in range(6 * cls.BLOCK_NUM))
        return [int(orig[i : i + 6], 2) for i in range(0, 6 * cls.BLOCK_NUM, 6)]

    @staticmethod
    def _xor(a: list[int], b: list[int]) -> list[int]:
        """XOR two blocks chunk by chunk."""
        return [x ^ y for x, y in zip(a, b)]

    def __init__(self, key: int):
        """Derive ROUND subkeys from the 48-bit master ``key``."""
        assert 0 <= key <= self.MASK

        keys = [key]
        for _ in range(self.ROUND):
            v = hashlib.sha256(str(keys[-1]).encode()).digest()
            v = int.from_bytes(v, "big") & self.MASK
            keys.append(v)

        self.subkeys = [self._divide(k) for k in keys]

    def encrypt(self, inp: int) -> int:
        """Encrypt one 48-bit block."""
        block = self._divide(inp)

        block = self._xor(block, self.subkeys[0])
        for r in range(self.ROUND):
            block = self._sub(block)
            block = self._perm(block)
            block = self._xor(block, self.subkeys[r + 1])

        return self._combine(block)

    def decrypt(self, inp: int) -> int:
        """Decrypt one 48-bit block by undoing ``encrypt`` step by step."""
        block = self._divide(inp)

        for r in reversed(range(self.ROUND)):
            block = self._xor(block, self.subkeys[r + 1])
            block = self._inv_perm(block)
            block = self._inv_sub(block)
        block = self._xor(block, self.subkeys[0])

        return self._combine(block)
 
 
def handler(_signum, _frame):
    """Alarm-signal callback: announce the timeout and end the session."""
    timeout_message = "Time out!"
    print(timeout_message)
    exit(0)
 
 
def main():
    """Run the encryption oracle; print the flag if the client sends exactly the key.

    Each input line is a comma-separated list of at most 0x100 integers. The
    session ends after 300 seconds, on malformed input, or on end of input.
    """
    signal.signal(signal.SIGALRM, handler)
    signal.alarm(300)
    key = int.from_bytes(os.urandom(6), "big")

    cipher = SusCipher(key)

    while True:
        try:
            inp = input("> ")
        except EOFError:
            return

        try:
            l = [int(v.strip()) for v in inp.split(",")]
        except ValueError:
            print("Wrong input!")
            exit(0)

        if len(l) > 0x100:
            print("Long input!")
            exit(0)

        if len(l) == 1 and l[0] == key:
            with open('flag', 'r') as f:
                print(f.read())

        print(", ".join(str(cipher.encrypt(v)) for v in l))


if __name__ == "__main__":
    main()
 
cs

 

This is a 3-round block cipher, and a hint is given - use differential cryptanalysis. 

 

Let's find some easy differentials. We collect the pairs $((i \oplus j)[u], (S_i \oplus S_j)[v])$ and see how they behave - and it turns out that the bias is noticeable. We can keep track of an array $pos$, where $pos_i$ denotes the bit location that is most correlated with the original $i$th bit. 

 

In the current state, there are 3 S-box applications, so the bias will decrease over those 3 S-boxes. However, if we know a key chunk, for example, the first 6 bits, then the first key addition and the S-box application can be computed directly, so in reality we are only going through 2 S-boxes. Therefore, the correlation between the state just after the first S-box and the encrypted state will be much greater.

We can now brute force all 64 possibilities for all 8 key chunks. 20K random encryptions are enough to reliably find the key.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from pwn import * 
import os 
 
# S-box and bit permutation, copied from the challenge's SusCipher class.
S = [
    43,  8, 57, 53, 48, 39, 15, 61,
     7, 44, 33,  9, 19, 41,  3, 14,
    42, 51,  6,  2, 49, 28, 55, 31,
     0,  4, 30,  1, 59, 50, 35, 47,
    25, 16, 37, 27, 10, 54, 26, 58,
    62, 13, 18, 22, 21, 24, 12, 20,
    29, 38, 23, 32, 60, 34,  5, 11,
    45, 63, 40, 46, 52, 36, 17, 56
]

P = [
    21,  8, 23,  6,  7, 15,
    22, 13, 19, 16, 25, 28,
    31, 32, 34, 36,  3, 39,
    29, 26, 24,  1, 43, 35,
    45, 12, 47, 17, 14, 11,
    27, 37, 41, 38, 40, 20,
     2,  0,  5,  4, 42, 18,
    44, 30, 46, 33,  9, 10
]
 
def get_bit(res, w):
    """Return bit ``w`` of a 6-bit value, where w = 0 is the most significant bit."""
    shift = 5 - w
    return 1 if res & (1 << shift) else 0
 
 
# bits[u][v][2*t1 + t2] counts, over all input pairs (i, j), how often bit u of
# i ^ j equals t1 while bit v of S[i] ^ S[j] equals t2.
bits = [[[0, 0, 0, 0] for _ in range(6)] for _ in range(6)]

for i in range(64):
    for j in range(64):
        # i ^ j => Si ^ Sj
        d_in = i ^ j
        d_out = S[i] ^ S[j]
        for u in range(6):
            t1 = get_bit(d_in, u)
            for v in range(6):
                t2 = get_bit(d_out, v)
                bits[u][v][2 * t1 + t2] += 1


# For each input bit u, print the output bit(s) v with the largest (0, 0) count.
for i in range(6):
    print("[+]", i)
    mx_j = max(bits[i][j][0] for j in range(6))
    print(mx_j)
    for j in range(6):
        if bits[i][j][0] == mx_j:
            print(j)


# Most-correlated S-box output bit for each input bit, read off the output above.
sub_loc = [5, 0, 1, 1, 0, 5]
 
 
def track_sub(pos):
    """Push tracked bit positions through the S-box layer.

    ``pos[i]`` is where original bit i currently sits. The 6-bit chunk is
    unchanged and the offset inside it becomes ``sub_loc[offset]``.
    """
    ret = []
    for p in pos:
        loc, md = divmod(p, 6)
        ret.append(6 * loc + sub_loc[md])
    return ret
 
def track_perm(pos):
    """Push tracked bit positions through the bit permutation (bit at p moves to P[p])."""
    return [P[p] for p in pos]
 
# Most-correlated ciphertext position of each plaintext bit after all 3 rounds.
full_enc = list(range(48))
for i in range(3):
    full_enc = track_sub(full_enc)
    full_enc = track_perm(full_enc)

# The same, starting right after the first S-box (the key chunk is guessed, so
# the first XOR and S-box are computed exactly): perm, then 2 x (sub, perm).
part_enc = list(range(48))
part_enc = track_perm(part_enc)
for i in range(2):
    part_enc = track_sub(part_enc)
    part_enc = track_perm(part_enc)

NUM_LINES = 80   # number of query lines sent
PER_LINE = 256   # the server accepts at most 0x100 values per line

plaintexts = [int.from_bytes(os.urandom(6), "big") for _ in range(PER_LINE * NUM_LINES)]

l = [
    ",".join(str(pt) for pt in plaintexts[PER_LINE * i : PER_LINE * (i + 1)])
    for i in range(NUM_LINES)
]


conn = remote("suscipher.chal.ctf.acsc.asia", 13579)

conn.sendlines(l)

results = conn.recvlines(NUM_LINES)

enc = []
for i in range(NUM_LINES):
    # Every response line starts with the "> " prompt.
    encs = results[i][2:].split(b",")
    assert len(encs) == PER_LINE
    enc.extend(int(x) for x in encs)


key = 0

for i in range(8):
    # res[k][loc]: how often S-box output bit `loc` (for key-chunk guess k) agrees
    # with the ciphertext bit it is most correlated with.
    res = [[0 for _ in range(6)] for _ in range(64)]
    fin = [0] * 64

    for plaintext, ciphertext in zip(plaintexts, enc):
        ptxt_block = (plaintext >> (6 * (7 - i))) & 63
        for loc in range(6):
            bit_res = (ciphertext >> (47 - part_enc[6 * i + loc])) & 1
            for k in range(64):
                bit_org = (S[k ^ ptxt_block] >> (5 - loc)) & 1
                if bit_org == bit_res:
                    res[k][loc] += 1

    # Score each guess by its total deviation from 50% agreement.
    for idx in range(64):
        for u in range(6):
            fin[idx] += abs(res[idx][u] - len(enc) // 2)

    # Best score; ties go to the largest guess.
    best = max(fin)
    ans = max(idx for idx in range(64) if fin[idx] == best)

    key += (ans << (6 * (7 - i)))

conn.sendline(str(key).encode())
print(conn.recvline())
 
 
 
 
cs

 

 

Serverless

Basically, the encryption system works as follows. 

  • Select one prime $p$ from a fixed array
  • Select one prime $q$ from a fixed array 
  • Select $0 \le s \le 4$ and choose $e = 2^{2^s} + 1$
  • Textbook RSA encrypt data, little-endian it, append some values ($s$ and the indexes for $p, q$)
  • XOR the password, reverse the data, then base64 encode it 

As we know the fixed array of primes and the password, the decryption is easy. 

 

'CTF' 카테고리의 다른 글

Paradigm CTF 2023 2nd Place  (0) 2023.10.31
CODEGATE 2023 Finals - The Leakers (1 solve)  (1) 2023.08.25
HackTM CTF Writeup  (0) 2023.02.22
BlackHat MEA Finals  (0) 2022.11.21
CODEGATE 2022 Finals: Look It Up  (0) 2022.11.09