PBCTF 2nd place, Super Guesser (apparently Super Guessers)

Solved: Goodhash, Yet Another RSA, Yet Another PRNG, Seed Me

 

Goodhash

#!/usr/bin/env python3
 
from Crypto.Cipher import AES
from Crypto.Util.number import *
from flag import flag
import json
import os
import string
 
ACCEPTABLE = string.ascii_letters + string.digits + string.punctuation + " "
 
 
class GoodHash:
    def __init__(self, v=b""):
        self.key = b"goodhashGOODHASH"
        self.buf = v
 
    def update(self, v):
        self.buf += v
 
    def digest(self):
        cipher = AES.new(self.key, AES.MODE_GCM, nonce=self.buf)
        enc, tag = cipher.encrypt_and_digest(b"\0" * 32)
        return enc + tag
 
    def hexdigest(self):
        return self.digest().hex()
 
 
if __name__ == "__main__":
    token = json.dumps({"token": os.urandom(16).hex(), "admin": False})
    token_hash = GoodHash(token.encode()).hexdigest()
    print(f"Body: {token}")
    print(f"Hash: {token_hash}")
 
    inp = input("> ")
    if len(inp) > 64 or any(v not in ACCEPTABLE for v in inp):
        print("Invalid input :(")
        exit(0)
 
    inp_hash = GoodHash(inp.encode()).hexdigest()
 
    if token_hash == inp_hash:
        try:
            token = json.loads(inp)
            if token["admin"] == True:
                print("Wow, how did you find a collision?")
                print(f"Here's the flag: {flag}")
            else:
                print("Nice try.")
                print("Now you need to set the admin value to True")
        except:
            print("Invalid input :(")
    else:
        print("Invalid input :(")
 

 

This is a hash collision challenge. Reading the code, we find the following two facts.

  • The hash is computed by using the input as the AES-GCM nonce and encrypting 32 zero bytes under a known key; the digest is the ciphertext followed by the tag.
  • Our colliding input must be valid JSON with "admin" set to true, at most 64 characters long, and made only of letters, digits, punctuation and spaces.

Usually, the AES-GCM nonce is 12 bytes. Here, however, the nonce can be much longer, which suggests that some logic compresses it into a fixed-size value. With this in mind, we look at the pycryptodome library code.

 

https://github.com/Legrandin/pycryptodome/blob/master/lib/Crypto/Cipher/_mode_gcm.py

 

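The important part begins at line 229. If the nonce length is not 12 bytes, the initial counter block is computed as the GHASH of $$ \text{pad}(m) || 0^{64} || \text{len}(m)$$ where $m$ is the nonce, $\text{pad}(m)$ is $m$ padded with zero bytes to a length that is a multiple of 16, and $\text{len}(m)$ is the bit length of $m$ encoded as a 64-bit integer. Since the key and the plaintext are fixed, the whole digest (ciphertext and tag) depends on the nonce only through this GHASH value, so colliding the GHASH is enough.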
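The layout of the GHASH input can be sketched in a few lines of Python. This follows the GCM specification (NIST SP 800-38D) for nonces that are not 96 bits long, which pycryptodome implements; `j0_ghash_input` is a made-up name, not a pycryptodome function. The point to notice is that any two 61-byte nonces share the same final 16-byte block, so only the first 64 bytes matter for a collision.

```python
def j0_ghash_input(nonce: bytes) -> bytes:
    """Bytes fed to GHASH to derive the initial counter block when len(nonce) != 12.

    Layout: nonce || zero padding to a 16-byte boundary || 0^64 || [8 * len(nonce)]_64
    """
    fill = (16 - len(nonce) % 16) % 16 + 8  # block padding plus the 64 zero bits
    return nonce + b"\x00" * fill + (8 * len(nonce)).to_bytes(8, "big")
```

For a 61-byte nonce this gives 80 bytes: the nonce, 3 padding bytes, then a final block of 8 zero bytes followed by the bit length 488.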

To compute the GHASH, we work in the finite field $GF(2^{128})$, denote $$H = \text{Enc}_{key}(0^{128})$$ and apply $$\text{GHASH}(X_1 || X_2 || \cdots || X_n) = X_1 H^n + X_2 H^{n-1} + \cdots + X_n H$$ Since the key is known, so is $H$. Therefore we can choose $n-1$ of the blocks in any way we want and still make the GHASH equal any target value, by solving for the remaining block: it appears multiplied by a nonzero power of $H$, which is invertible.
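A minimal pure-Python sketch of this arithmetic, assuming GCM's bit-reflected convention for $GF(2^{128})$ (the multiplication loop follows Algorithm 1 of NIST SP 800-38D). `gf_mult`, `ghash` and `solve_block` are hypothetical helper names, and `h` is any nonzero stand-in for $H$ rather than an actual AES output.

```python
R = 0xE1 << 120   # reduction constant for x^128 + x^7 + x^2 + x + 1 (bit-reflected)
ONE = 1 << 127    # the field element 1 in GCM's convention


def gf_mult(x, y):
    """Multiply two GF(2^128) elements given as 128-bit ints (GCM convention)."""
    z, v = 0, x
    for i in range(127, -1, -1):       # bits of y, leftmost first
        if (y >> i) & 1:
            z ^= v
        v = (v >> 1) ^ R if v & 1 else v >> 1
    return z


def gf_pow(x, e):
    result = ONE
    while e:
        if e & 1:
            result = gf_mult(result, x)
        x = gf_mult(x, x)
        e >>= 1
    return result


def gf_inv(x):
    return gf_pow(x, (1 << 128) - 2)   # x^(2^128 - 2) = x^(-1) for x != 0


def ghash(blocks, h):
    """X_1 h^n + X_2 h^(n-1) + ... + X_n h, evaluated with Horner's rule."""
    y = 0
    for x in blocks:
        y = gf_mult(y ^ x, h)
    return y


def solve_block(blocks, j, target, h):
    """Value for blocks[j] (0-indexed) that makes ghash(blocks, h) == target."""
    rest = ghash(blocks[:j] + [0] + blocks[j + 1:], h)
    # blocks[j] enters the sum multiplied by h^(n - j); addition is XOR.
    return gf_mult(target ^ rest, gf_inv(gf_pow(h, len(blocks) - j)))
```

The caller assigns the returned value to `blocks[j]`. Inversion via $x^{2^{128}-2}$ is slow but simple, which is fine for a handful of calls.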

 

Solution 1

 

The above fact gives us one immediate idea. We can attempt to construct a bytearray that 

  • Has length 61, the same as the original JSON, so that the final length block of the GHASH input is also the same
  • Can be parsed as JSON, with "admin" set to true
  • Has the same GHASH as the original JSON after both are padded to 64 bytes (i.e. 4 blocks)

To do so, we can fix 2 of the 4 blocks so that the bytearray is a JSON object with "admin" set to true, select one of the remaining blocks arbitrarily, then compute the last block so that the GHASH matches, hoping that all four blocks contain only allowed characters. For example, we can make the bytearray start with {"admin":true,"a and end with ":"abcdefgh"}\x00\x00\x00, since a length of 61 means that \x00\x00\x00 will be appended as padding. Now we can randomly select a 16-byte string of allowed characters as the second block, then compute the third block by matching the GHASH, hoping that the third block also consists of allowed characters and does not break the JSON (it sits inside a string, so it must avoid " and \). This works, and some people definitely solved the challenge this way, but it is not very efficient: the probability of success per trial is quite low, and each trial requires some computation.
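A sketch of a single trial of this approach, again in pure Python with a random stand-in `h` for $H$ and hypothetical helper names. If the solved third block behaves like 16 uniformly random bytes, each byte is usable with probability 93/256 (the 95 allowed characters minus `"` and `\`), so a trial succeeds with probability roughly $(93/256)^{16} \approx 2^{-23}$; this estimate is a back-of-the-envelope assumption, not something measured.

```python
import random
import string

ALLOWED = string.ascii_letters + string.digits + string.punctuation + " "
SAFE = set(ALLOWED.replace('"', "").replace("\\", "").encode())  # usable inside a JSON string
R = 0xE1 << 120
ONE = 1 << 127
B1 = b'{"admin":true,"a'            # block 1, fixed
B4 = b'":"abcdefgh"}\x00\x00\x00'    # block 4, fixed (last 3 bytes are GCM padding)


def gf_mult(x, y):
    """GF(2^128) multiplication in GCM's bit-reflected convention."""
    z, v = 0, x
    for i in range(127, -1, -1):
        if (y >> i) & 1:
            z ^= v
        v = (v >> 1) ^ R if v & 1 else v >> 1
    return z


def gf_inv(x):
    result, e = ONE, (1 << 128) - 2
    while e:
        if e & 1:
            result = gf_mult(result, x)
        x = gf_mult(x, x)
        e >>= 1
    return result


def ghash(blocks, h):
    y = 0
    for x in blocks:
        y = gf_mult(y ^ x, h)
    return y


def try_once(h, h2_inv, target, rng):
    """One trial: random block 2, solved block 3. Returns (61-byte body, usable?)."""
    b2 = bytes(rng.choice(b"abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(16))
    x1, x2, x4 = (int.from_bytes(b, "big") for b in (B1, b2, B4))
    rest = ghash([x1, x2, 0, x4], h)         # everything except X3 * h^2
    x3 = gf_mult(target ^ rest, h2_inv)       # X3 = (target - rest) / h^2, minus is XOR
    b3 = x3.to_bytes(16, "big")
    return (B1 + b2 + b3 + B4)[:61], all(c in SAFE for c in b3)
```

A driver would loop `try_once` until the second return value is true; at millions of expected trials in pure Python, this illustrates why the approach is slow.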

 

Solution 2 

 

In my opinion, the cleaner way to solve this challenge is to view the GHASH equation not as a linear equation in the blocks, but as a linear equation in the bits that make up those blocks. Indeed, GHASH only uses XOR and multiplication by fixed powers of $H$, both of which are linear over $GF(2)$, so we can treat the bytearray as a bit vector and GHASH remains a linear function of it. Therefore, the GHASH equation is just a system of linear equations over $GF(2)$, where the variables are the 512 bits (64 bytes) of the padded bytearray. Let's keep track of the equations that we have.

  • Since the equation is over $GF(2^{128})$, we can convert it into 128 linear equations over $GF(2)$.
  • We can fix some characters: I made my padded JSON start with {"admin": true, "a": " and end with "}\x00\x00\x00 (the three zero bytes being the padding).
  • This is a total of 22 + 5 = 27 characters, which is equivalent to 216 fixed bits over $GF(2)$.

With 512 unknown bits and only 128 + 216 = 344 equations so far, we can definitely solve this. However, the issue of allowed characters remains.

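To reduce the amount of trial and error, we add one more idea: make every remaining character's ASCII value lie between 64 and 95.

This can be done by forcing the three most significant bits of each such byte (the bits of value 128, 64 and 32) to be 0, 1 and 0 respectively. Every character in this range (@, A-Z, [, \, ], ^, _) is allowed; only the backslash can still break the JSON string, so a little trial and error remains.

  • Since we have 37 characters remaining, this gives us an additional 111 fixed bits over $GF(2)$.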

Now $128 + 216 + 111 = 455$ is still below $512$, so the solution space has dimension at least $512 - 455 = 57$. We can just solve this matrix equation, try random solutions using its kernel basis, and keep going until we find a good collision. Very fun challenge to work on :)
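As a rough, dependency-free illustration of the same idea (pure Python with a hand-written $GF(2)$ elimination instead of Sage), the sketch below builds the 455 bit equations and samples random solutions until the body parses. Assumptions: `h` is a random stand-in for $H = \text{Enc}_{key}(0^{128})$ so that no AES library is needed, and `ghash4`, `rref` and `find_collision` are hypothetical names. Comparing GHASH over the four padded blocks is enough because both inputs share the same length block and $H \neq 0$.

```python
import json
import os
import random
import string

ACCEPTABLE = string.ascii_letters + string.digits + string.punctuation + " "
R = 0xE1 << 120                      # GCM reduction constant (bit-reflected)
PREFIX = b'{"admin": true, "a": "'   # 22 fixed bytes
SUFFIX = b'"}\x00\x00\x00'           # 2 fixed bytes + 3 bytes of zero padding


def gf_mult(x, y):
    """Multiply in GF(2^128) using GCM's bit-reflected convention."""
    z, v = 0, x
    for i in range(127, -1, -1):
        if (y >> i) & 1:
            z ^= v
        v = (v >> 1) ^ R if v & 1 else v >> 1
    return z


def ghash4(msg, h):
    """GHASH of the four 16-byte blocks of a 64-byte message (no length block)."""
    y = 0
    for b in range(0, 64, 16):
        y = gf_mult(y ^ int.from_bytes(msg[b:b + 16], "big"), h)
    return y


def parity(v):
    return bin(v).count("1") & 1


def rref(rows):
    """Reduced row echelon form over GF(2); rows are (bitmask, rhs) pairs.

    Returns [(pivot_bit, mask, rhs), ...], or None if the system is inconsistent.
    """
    piv = []
    for mask, rhs in rows:
        for pb, pm, pr in piv:
            if mask & pb:
                mask, rhs = mask ^ pm, rhs ^ pr
        if not mask:
            if rhs:
                return None
            continue
        top = 1 << (mask.bit_length() - 1)
        piv = [(pb, pm ^ mask, pr ^ rhs) if pm & top else (pb, pm, pr)
               for pb, pm, pr in piv]
        piv.append((top, mask, rhs))
    return piv


def find_collision(h, target, rng, tries=200):
    """Search for a 61-byte admin JSON whose zero-padded ghash4 equals target."""
    # Message bit j (MSB first) is integer bit 511 - j of the 512-bit message.
    images = []
    for j in range(512):
        unit = bytearray(64)
        unit[j // 8] = 0x80 >> (j % 8)
        images.append(ghash4(bytes(unit), h))    # GHASH is linear over GF(2)
    rows = []
    for r in range(128):                          # 128 equations from GHASH
        mask = sum(1 << (511 - j) for j in range(512) if (images[j] >> r) & 1)
        rows.append((mask, (target >> r) & 1))
    fixed = dict(enumerate(PREFIX))
    fixed.update({64 - len(SUFFIX) + i: c for i, c in enumerate(SUFFIX)})
    for pos in range(64):
        if pos in fixed:                          # 27 fully fixed bytes
            bits = [(k, (fixed[pos] >> (7 - k)) & 1) for k in range(8)]
        else:                                     # top bits 010: value in 64..95
            bits = [(0, 0), (1, 1), (2, 0)]
        rows += [(1 << (511 - (8 * pos + k)), val) for k, val in bits]
    piv = rref(rows)
    if piv is None:
        return None
    free = ((1 << 512) - 1) ^ sum(pb for pb, _, _ in piv)
    for _ in range(tries):
        m = rng.getrandbits(512) & free           # random point of the solution space
        for pb, pm, pr in piv:
            if pr ^ parity(pm & m):
                m |= pb
        body = m.to_bytes(64, "big")[:61]
        try:
            if all(chr(c) in ACCEPTABLE for c in body) and json.loads(body)["admin"] is True:
                return body
        except (ValueError, KeyError):            # e.g. a stray backslash
            pass
    return None
```

Only the backslash can make a candidate fail, so a handful of samples is usually enough once the system is solved.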

 

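# Sage solve script; assumes pycryptodome and pwntools are installed
from Crypto.Cipher import AES
from Crypto.Util.number import bytes_to_long, long_to_bytes
from pwn import remote
import json
import random as rand
import string
 
class GoodHash: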
    def __init__(self, v=b""):
        self.key = b"goodhashGOODHASH"
        self.buf = v
 
    def update(self, v):
        self.buf += v
 
    def digest(self):
        cipher = AES.new(self.key, AES.MODE_GCM, nonce=self.buf)
        enc, tag = cipher.encrypt_and_digest(b"\0" * 32)
        return enc + tag
 
    def hexdigest(self):
        return self.digest().hex()
 
POL = PolynomialRing(GF(2), 'a')
a = POL.gen()
F = GF(2 ** 128, name = 'a', modulus = a ** 128 + a ** 7 + a ** 2 + a + 1)
 
def aes_enc(p, k):
    cipher = AES.new(key = k, mode = AES.MODE_ECB)
    return cipher.encrypt(p)
 
def int_to_finite(v):
    bin_block = bin(v)[2:].zfill(128)
    res = 0
    for i in range(128):
        res += (a ** i) * int(bin_block[i])
    return F(res)
 
def bytes_to_finite(v):
    v = bytes_to_long(v)
    return int_to_finite(v)
 
def finite_to_int(v):
    v = POL(v)
    res = v.coefficients(sparse = False)
    ret = 0
    for i in range(len(res)):
        ret += int(res[i]) * (1 << (127 - i))
    return ret
 
def finite_to_bytes(v):
    cc = finite_to_int(v)
    return long_to_bytes(cc, blocksize = 16)
 
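def hasher(v):
    # Returns the 512 per-bit contributions and sum_i X_i * H^(3 - i) over the 4 blocks.
    # This is the real GHASH (minus the shared length block) divided by H^2,
    # which is enough to compare two 64-byte inputs of the same length.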
    H = aes_enc(b"\x00" * 16, b"goodhashGOODHASH")
    H_f = bytes_to_finite(H)
    ret = F(0)
    res = bytes_to_long(v)
    bin_block = bin(res)[2:].zfill(512)
    bas = []
    for i in range(512):
        cc = F(a ** int(i % 128)) * F(H_f ** (3 - i // 128)) 
        bas.append(finite_to_int(cc))
        ret += F(a ** int(i % 128)) * F(H_f ** (3 - i // 128)) * int(bin_block[i])
    return bas, finite_to_int(ret)
 
ACCEPTABLE = string.ascii_letters + string.digits + string.punctuation + " "
print(ACCEPTABLE)
 
conn = remote('good-hash.chal.perfect.blue', 1337)
body = conn.recvline()[6:-1]
print(body)
print(len(body))
print(conn.recvline())
 
bases, target = hasher(body + b"\x00\x00\x00")
 
starter = b'{"admin": true, "a": "'
finisher = b'"}\x00\x00\x00'
print(len(starter) + len(finisher))
 
print("[+] Building Matrix")
 
SZ = 128 + 37 * 3 + 27 * 8
M = Matrix(GF(2), SZ, 512)
vv = []
 
for i in range(128):
    for j in range(512):
        M[i, j] = (bases[j] >> i) & 1
    vv.append((target >> i) & 1)
 
for i in range(37):
    M[3 * i + 128, 8 * (22 + i)] = 1
    vv.append(0) # 128
    M[3 * i + 128 + 1, 8 * (22 + i) + 1] = 1
    vv.append(1) # 64
    M[3 * i + 128 + 2, 8 * (22 + i) + 2] = 1
    vv.append(0) # 32
 
for i in range(22):
    for j in range(8):
        M[8 * i + j + 37 * 3 + 128, 8 * i + j] = 1
        vv.append((int(starter[i]) >> (7 - j)) & 1)
for i in range(5):
    for j in range(8):
        M[8 * i + j + 37 * 3 + 22 * 8 + 128, 8 * (59 + i) + j] = 1
        vv.append((int(finisher[i]) >> (7 - j)) & 1)
 
vv = vector(GF(2), vv)
val = M.solve_right(vv)
kernels = M.right_kernel().basis()
 
print("[+] Finished Solving Matrix, Finding Collision Now...")
 
attempts = 0
 
while True:
    attempts += 1
    print(attempts)
    cur = val 
    for i in range(len(kernels)):
        cur += (kernels[i] * GF(2)(rand.randint(0, 1)))
    fins = 0
    for i in range(512):
        fins = 2 * fins + int(cur[i])
    fins = long_to_bytes(fins)
    print(fins)
    fins = fins[:61]
    print(fins, len(fins))
    try:
        if len(fins) == 61 and (any(v not in ACCEPTABLE for v in fins.decode()) == False):
            token = json.loads(fins)
            bases2, finresult = hasher(fins + b"\x00\x00\x00")
            print(GoodHash(body + b"\x00\x00\x00").hexdigest())
            print(GoodHash(fins + b"\x00\x00\x00").hexdigest())
            print(target)
            print(finresult)
            print(token)
            if token["admin"] == True:
                conn.sendline(fins)
                print(conn.recvline())
                print(conn.recvline())
                break
    except:
        pass

 

 

Yet Another RSA

#!/usr/bin/env python3
 
from Crypto.Util.number import *
import random
 
 
def genPrime():
    while True:
        a = random.getrandbits(256)
        b = random.getrandbits(256)
 
        if b % 3 == 0:
            continue
 
        p = a ** 2 + 3 * b ** 2
        if p.bit_length() == 512 and p % 3 == 1 and isPrime(p):
            return p
 
 
def add(P, Q, mod):
    m, n = P
    p, q = Q
 
    if p is None:
        return P
    if m is None:
        return Q
 
    if n is None and q is None:
        x = m * p % mod
        y = (m + p) % mod
        return (x, y)
 
    if n is None and q is not None:
        m, n, p, q = p, q, m, n
 
    if q is None:
        if (n + p) % mod != 0:
            x = (m * p + 2) * inverse(n + p, mod) % mod
            y = (m + n * p) * inverse(n + p, mod) % mod
            return (x, y)
        elif (m - n ** 2) % mod != 0:
            x = (m * p + 2) * inverse(m - n ** 2, mod) % mod
            return (x, None)
        else:
            return (None, None)
    else:
        if (m + p + n * q) % mod != 0:
            x = (m * p + (n + q) * 2) * inverse(m + p + n * q, mod) % mod
            y = (n * p + m * q + 2) * inverse(m + p + n * q, mod) % mod
            return (x, y)
        elif (n * p + m * q + 2) % mod != 0:
            # denominator is n * p + m * q + 2, matching the guard above
            x = (m * p + (n + q) * 2) * inverse(n * p + m * q + 2, mod) % mod
            return (x, None)
        else:
            return (None, None)
 
 
def power(P, a, mod):
    res = (None, None)
    t = P
    while a > 0:
        if a % 2:
            res = add(res, t, mod)
        t = add(t, t, mod)
        a >>= 1
    return res
 
 
def random_pad(msg, ln):
    pad = bytes([random.getrandbits(8) for _ in range(ln - len(msg))])
    return msg + pad
 
 
p, q = genPrime(), genPrime()
N = p * q
phi = (p ** 2 + p + 1) * (q ** 2 + q + 1)
 
print(f"N: {N}")
 
d = getPrime(400)
e = inverse(d, phi)
k = (e * d - 1) // phi
 
print(f"e: {e}")
 
to_enc = input("> ").encode()
ln = len(to_enc)
 
print(f"Length: {ln}")
 
pt1, pt2 = random_pad(to_enc[: ln // 2], 127), random_pad(to_enc[ln // 2 :], 127)
 
M = (bytes_to_long(pt1), bytes_to_long(pt2))
E = power(M, e, N)
 
print(f"E: {E}")
 

 

The obvious weird part of the script, apart from the whole mysterious group, is that $d$ is very small: it is a 400-bit prime, while $N$ has about 1024 bits and $\phi$ about 2048 bits.

This leads to ideas like Wiener's attack or the Boneh-Durfee attack. Wiener's attack does not work well here: we can only approximate $\phi$ as $N^2$, with an error of order $N^{1.5}$, which is too coarse for a 400-bit $d$. To be honest, I forgot about Boneh-Durfee and just started googling "Wiener's attack modulo $(p^2+p+1)(q^2+q+1)$". This led me to the paper https://eprint.iacr.org/2021/1160.pdf, which contains all the ideas as well as the solution to this problem. It also explains where the group comes from; I'll explain this part later.

 

Since the paper explains very well which polynomials to feed into LLL, I implemented them directly on top of https://github.com/mimoo/RSA-and-LLL-attacks/blob/master/boneh_durfee.sage instead of using defund's black-box (?) script.
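The polynomial `f` in the `attack` function below comes from writing $\phi$ in terms of $s = p + q$. Using $p^2 + q^2 = s^2 - 2N$, $$\phi = (p^2+p+1)(q^2+q+1) = s^2 + (N+1)s + (N^2 - N + 1),$$ and since $ed = 1 + k\phi$, the pair $(k, s)$ is a small root of $f(x, y) = x(y^2 + (N+1)y + N^2 - N + 1) + 1$ modulo $e$. A quick Python check of this identity with toy primes (not primes produced by the challenge's `genPrime`):

```python
def phi_from_sum(N, s):
    """phi = (p^2+p+1)(q^2+q+1) rewritten with N = p*q and s = p+q."""
    return s * s + (N + 1) * s + (N * N - N + 1)


def f(x, y, N):
    """Polynomial from `attack`: f(k, p+q) = k*phi + 1, which equals e*d."""
    a = N + 1
    b = N * N - N + 1
    return x * (y * y + a * y + b) + 1
```

For example, with $p = 7$, $q = 13$ we get $N = 91$, $s = 20$ and $\phi = 57 \cdot 183 = 10431$.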

 

The Group

 

I figured this part out before searching for the paper, though it doesn't really help with solving the challenge.

At first I thought this was some sort of curve, but I couldn't figure out its equation. I tried to recover the curve equation by evaluating various monomials in the coordinates of points in the group and taking the kernel of the resulting matrix, but that failed as well. (For example, see the "Bonus" from hellman's writeup on CONFidence 2020 Finals https://nbviewer.org/gist/hellman/be17ac7b2363dd0cf6cca89c6a9e69bf)

This suggested that the group might not be a curve at all. Now what do we do?

 

Then I looked at the $(m+p+nq)$ term. What could produce a term like that? After some thought, I found $$(x^2+nx+m)(x^2+qx+p) = x^4 + (n + q)x^3 + (m + p + nq)x^2 + (np + mq)x + mp$$ which looked really suspicious. In the case where nothing is "None" and $m + p + nq$ is nonzero, the code divides by $m + p + nq$ to get the final $x, y$, which suggests that the result is being made monic. The constants $2$ and $2(n+q)$ are also suspicious, and they are explained by reducing modulo $x^3 - 2$: since $x^3 \equiv 2$ and $x^4 \equiv 2x$, the product becomes $$(m + p + nq)x^2 + (np + mq + 2) x + (mp + 2 (n + q))$$ Making this monic and reading off the coefficients gives exactly the $x, y$ from the code, where a point $(m, n)$ stands for the polynomial $x^2 + nx + m$. The "None" cases correspond to polynomials that are not quadratic but linear, or even constant. For example, when $n$ and $q$ are "None", the points stand for $x + m$ and $x + p$, and $(x+m)(x+p) = x^2 + (m+p)x + mp$ gives the point $(mp, m+p)$. The other cases are similar and are left as exercises for the reader.
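To double-check this reading, the Python sketch below compares the generic branch of the challenge's `add` with polynomial multiplication modulo $x^3 - 2$ followed by making the result monic. The helper names are made up for illustration, and `pow(a, -1, mod)` (Python 3.8+) stands in for pycryptodome's `inverse`.

```python
def add_generic(P, Q, mod):
    """Generic branch of the challenge's `add` (both points quadratic)."""
    (m, n), (u, v) = P, Q
    den = pow(m + u + n * v, -1, mod)   # requires m + u + n*v != 0 (mod p)
    return ((m * u + (n + v) * 2) * den % mod, (n * u + m * v + 2) * den % mod)


def as_poly(P):
    """Point (m, n) <-> x^2 + n*x + m, as coefficients [const, x, x^2]."""
    m, n = P
    return [m, n, 1]


def mul_mod_x3_minus_2(A, B, mod):
    """Multiply coefficient lists (constant term first) and reduce using x^3 = 2."""
    c = [0] * (len(A) + len(B) - 1)
    for i, a in enumerate(A):
        for j, b in enumerate(B):
            c[i + j] = (c[i + j] + a * b) % mod
    for k in range(len(c) - 1, 2, -1):  # x^k = 2 * x^(k - 3)
        c[k - 3] = (c[k - 3] + 2 * c[k]) % mod
    return c[:3]


def monic_coords(c, mod):
    """Divide c0 + c1*x + c2*x^2 by c2 and return the point (c0/c2, c1/c2)."""
    inv = pow(c[2], -1, mod)
    return (c[0] * inv % mod, c[1] * inv % mod)
```

The reduction step only ever touches the coefficients of $x^3$ and $x^4$, which is where the extra $2$ and $2(n+q)$ in the code come from.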

 

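Now we can even compute the group order. If $x^3 - 2$ is irreducible over $GF(p)$, then $GF(p)[x]/(x^3 - 2)$ is the field $GF(p^3)$, and the group is its multiplicative group with every element normalized to be monic, i.e. $GF(p^3)^*/GF(p)^*$.

This means that the group size is $$ \frac{p^3 - 1}{p - 1} = p^2 + p +1$$ and, by the Chinese remainder theorem, the group modulo $N = pq$ has order $(p^2+p+1)(q^2+q+1)$, which matches the $\phi$ in the challenge source code.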
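A brute-force sanity check for tiny primes: the sketch below reproduces the challenge's `add` and `power` (with `pow(a, -1, mod)` in place of pycryptodome's `inverse`, and the last branch's denominator written as `n * p + m * q + 2`, matching its guard) and raises every element to $p^2 + p + 1$. It assumes $p = 7$ and $p = 13$, for which $x^3 - 2$ is irreducible.

```python
def inverse(a, mod):
    return pow(a % mod, -1, mod)


def add(P, Q, mod):
    """The challenge's group law; (None, None) is the identity."""
    m, n = P
    p, q = Q
    if p is None:
        return P
    if m is None:
        return Q
    if n is None and q is None:                   # (x + m)(x + p)
        return (m * p % mod, (m + p) % mod)
    if n is None and q is not None:               # put the quadratic first
        m, n, p, q = p, q, m, n
    if q is None:                                 # (x^2 + n x + m)(x + p)
        if (n + p) % mod != 0:
            den = inverse(n + p, mod)
            return ((m * p + 2) * den % mod, (m + n * p) * den % mod)
        if (m - n ** 2) % mod != 0:
            return ((m * p + 2) * inverse(m - n ** 2, mod) % mod, None)
        return (None, None)
    if (m + p + n * q) % mod != 0:                # quadratic times quadratic
        den = inverse(m + p + n * q, mod)
        return ((m * p + (n + q) * 2) * den % mod, (n * p + m * q + 2) * den % mod)
    if (n * p + m * q + 2) % mod != 0:
        return ((m * p + (n + q) * 2) * inverse(n * p + m * q + 2, mod) % mod, None)
    return (None, None)


def power(P, a, mod):
    res, t = (None, None), P
    while a > 0:
        if a % 2:
            res = add(res, t, mod)
        t = add(t, t, mod)
        a >>= 1
    return res
```

Every one of the $p^2 + p$ non-identity elements (monic quadratics and monic linears) should return to the identity after $p^2 + p + 1$ steps.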

 

Is $x^3 - 2$ irreducible? It turns out, yes. Since it is a cubic, it is irreducible over $GF(p)$ exactly when it has no root, i.e. when $2$ is not a cube modulo $p$. When $p \equiv 1 \pmod{3}$, results on cubic reciprocity state that $p$ can be expressed uniquely (up to signs) as $p = a^2 + 3b^2$, and that $2$ is a cubic residue modulo $p$ if and only if $b \equiv 0 \pmod{3}$. Check https://en.wikipedia.org/wiki/Cubic_reciprocity. Our prime generation rejects exactly this case (it skips $b$ divisible by 3), so $x^3 - 2$ has no roots over $GF(p)$ and is hence irreducible.
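The criterion is easy to check numerically. For $p \equiv 1 \pmod 3$ the group $GF(p)^*$ is cyclic of order divisible by 3, so $2$ is a cube exactly when $2^{(p-1)/3} \equiv 1 \pmod p$. The sketch below (helper names are made up) compares this test with the $b \bmod 3$ rule for small primes.

```python
import math


def is_prime(n):
    if n < 2:
        return False
    return all(n % d for d in range(2, math.isqrt(n) + 1))


def rep_b(p):
    """b > 0 with p = a^2 + 3*b^2 (unique for primes p = 1 mod 3), else None."""
    b = 1
    while 3 * b * b < p:
        a2 = p - 3 * b * b
        if math.isqrt(a2) ** 2 == a2:
            return b
        b += 1
    return None


def two_is_cube(p):
    """For p = 1 (mod 3): 2 is a cube mod p iff 2^((p-1)/3) = 1 (mod p)."""
    return pow(2, (p - 1) // 3, p) == 1
```

For instance, $7 = 2^2 + 3 \cdot 1^2$ and $2$ is not a cube mod 7, while $31 = 2^2 + 3 \cdot 3^2$ and $4^3 = 64 \equiv 2 \pmod{31}$.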

 

import time
 
############################################
# Config
##########################################
 
"""
Setting debug to true will display more information
about the lattice, the bounds, the vectors...
"""
debug = True
 
"""
Setting strict to true will stop the algorithm (and
return (-1, -1)) if we don't have a correct 
upperbound on the determinant. Note that this 
doesn't necessarily mean that no solutions 
will be found since the theoretical upperbound is
usually far away from actual results. That is why
you should probably use `strict = False`
"""
strict = False
 
"""
This is experimental, but has provided remarkable results
so far. It tries to reduce the lattice as much as it can
while keeping its efficiency. I see no reason not to use
this option, but if things don't work, you should try
disabling it
"""
helpful_only = True
dimension_min = 7 # stop removing if lattice reaches that dimension
 
############################################
# Functions
##########################################
 
# display stats on helpful vectors
def helpful_vectors(BB, modulus):
    nothelpful = 0
    for ii in range(BB.dimensions()[0]):
        if BB[ii,ii] >= modulus:
            nothelpful += 1
 
    print(nothelpful, "/", BB.dimensions()[0], " vectors are not helpful")
 
# display matrix picture with 0 and X
def matrix_overview(BB, bound):
    for ii in range(BB.dimensions()[0]):
        a = ('%02d ' % ii)
        for jj in range(BB.dimensions()[1]):
            a += '0' if BB[ii,jj] == 0 else 'X'
            if BB.dimensions()[0] < 60:
                a += ' '
        if BB[ii, ii] >= bound:
            a += '~'
        print(a)
 
# tries to remove unhelpful vectors
# we start at current = n-1 (last vector)
def remove_unhelpful(BB, monomials, bound, current):
    # end of our recursive function
    if current == -1 or BB.dimensions()[0] <= dimension_min:
        return BB
 
    # we start by checking from the end
    for ii in range(current, -1, -1):
        # if it is unhelpful:
        if BB[ii, ii] >= bound:
            affected_vectors = 0
            affected_vector_index = 0
            # let's check if it affects other vectors
            for jj in range(ii + 1, BB.dimensions()[0]):
                # if another vector is affected:
                # we increase the count
                if BB[jj, ii] != 0:
                    affected_vectors += 1
                    affected_vector_index = jj
 
            # level:0
            # if no other vectors end up affected
            # we remove it
            if affected_vectors == 0:
                print("* removing unhelpful vector", ii)
                BB = BB.delete_columns([ii])
                BB = BB.delete_rows([ii])
                monomials.pop(ii)
                BB = remove_unhelpful(BB, monomials, bound, ii-1)
                return BB
 
            # level:1
            # if just one was affected we check
            # if it is affecting someone else
            elif affected_vectors == 1:
                affected_deeper = True
                for kk in range(affected_vector_index + 1, BB.dimensions()[0]):
                    # if it is affecting even one vector
                    # we give up on this one
                    if BB[kk, affected_vector_index] != 0:
                        affected_deeper = False
                # remove both it if no other vector was affected and
                # this helpful vector is not helpful enough
                # compared to our unhelpful one
                if affected_deeper and abs(bound - BB[affected_vector_index, affected_vector_index]) < abs(bound - BB[ii, ii]):
                    print("* removing unhelpful vectors", ii, "and", affected_vector_index)
                    BB = BB.delete_columns([affected_vector_index, ii])
                    BB = BB.delete_rows([affected_vector_index, ii])
                    monomials.pop(affected_vector_index)
                    monomials.pop(ii)
                    BB = remove_unhelpful(BB, monomials, bound, ii-1)
                    return BB
    # nothing happened
    return BB
 
 
def attack(N, e, m, t, X, Y):
    modulus = e
 
    PR.<x, y> = PolynomialRing(ZZ)
    a = N + 1
    b = N * N - N + 1
    f = x * (y * y + a * y + b) + 1
 
    gg = []
    for k in range(0, m+1):
        for i in range(k, m+1):
            for j in range(2 * k, 2 * k + 2):
                gg.append(x^(i-k) * y^(j-2*k) * f^k * e^(m - k))
    for k in range(0, m+1):
        for i in range(k, k+1):
            for j in range(2*k+2, 2*i+t+1):
                gg.append(x^(i-k) * y^(j-2*k) * f^k * e^(m - k))
 
    def order_gg(idx, gg, monomials):
        if idx == len(gg):
            return gg, monomials
 
        for i in range(idx, len(gg)):
            polynomial = gg[i]
            non = []
            for monomial in polynomial.monomials():
                if monomial not in monomials:
                    non.append(monomial)
            
            if len(non) == 1:
                new_gg = gg[:]
                new_gg[i], new_gg[idx] = new_gg[idx], new_gg[i]
 
                return order_gg(idx + 1, new_gg, monomials + non)    
 
    gg, monomials = order_gg(0, gg, [])
 
    # construct lattice B
    nn = len(monomials)
    BB = Matrix(ZZ, nn)
    for ii in range(nn):
        BB[ii, 0] = gg[ii](0, 0)
        for jj in range(1, nn):
            if monomials[jj] in gg[ii].monomials():
                BB[ii, jj] = gg[ii].monomial_coefficient(monomials[jj]) * monomials[jj](X, Y)
 
    # Prototype to reduce the lattice
    if helpful_only:
        # automatically remove
        BB = remove_unhelpful(BB, monomials, modulus^m, nn-1)
        # reset dimension
        nn = BB.dimensions()[0]
        if nn == 0:
            print("failure")
            return 0,0
 
    # check if vectors are helpful
    if debug:
        helpful_vectors(BB, modulus^m)
    
    # check if determinant is correctly bounded
    det = BB.det()
    bound = modulus^(m*nn)
    if det >= bound:
        print("We do not have det < bound. Solutions might not be found.")
        print("Try with higher m and t.")
        if debug:
            diff = (log(det) - log(bound)) / log(2)
            print("size det(L) - size e^(m*n) = ", floor(diff))
        if strict:
            return -1, -1
    else:
        print("det(L) < e^(m*n) (good! If a solution exists < N^delta, it will be found)")
 
    # display the lattice basis
    if debug:
        matrix_overview(BB, modulus^m)
 
    # LLL
    if debug:
        print("optimizing basis of the lattice via LLL, this can take a long time")
 
    BB = BB.LLL()
 
    if debug:
        print("LLL is done!")
 
    # transform vector i & j -> polynomials 1 & 2
    if debug:
        print("looking for independent vectors in the lattice")
    found_polynomials = False
    
    for pol1_idx in range(nn - 1):
        for pol2_idx in range(pol1_idx + 1, nn):
            # for i and j, create the two polynomials
            PR.<a, b> = PolynomialRing(ZZ)
            pol1 = pol2 = 0
            for jj in range(nn):
                pol1 += monomials[jj](a,b) * BB[pol1_idx, jj] / monomials[jj](X, Y)
                pol2 += monomials[jj](a,b) * BB[pol2_idx, jj] / monomials[jj](X, Y)
 
            # resultant
            PR.<q> = PolynomialRing(ZZ)
            rr = pol1.resultant(pol2)
 
            # are these good polynomials?
            if rr.is_zero() or rr.monomials() == [1]:
                continue
            else:
                print("found them, using vectors", pol1_idx, "and", pol2_idx)
                found_polynomials = True
                break
        if found_polynomials:
            break
 
    if not found_polynomials:
        print("no independent vectors could be found. This should very rarely happen...")
        return 0, 0
    
    rr = rr(q, q)
 
    # solutions
    soly = rr.roots()
 
    if len(soly) == 0:
        print("Your prediction (delta) is too small")
        return 0, 0
 
    soly = soly[0][0]
    ss = pol1(q, soly)
    solx = ss.roots()[0][0]
 
    return solx, soly
 
def inthroot(a, n):
    return a.nth_root(n, truncate_mode=True)[0]
 
N = 144256630216944187431924086433849812983170198570608223980477643981288411926131676443308287340096924135462056948517281752227869929565308903867074862500573343002983355175153511114217974621808611898769986483079574834711126000758573854535492719555861644441486111787481991437034260519794550956436261351981910433997
e = 3707368479220744733571726540750753259445405727899482801808488969163282955043784626015661045208791445735104324971078077159704483273669299425140997283764223932182226369662807288034870448194924788578324330400316512624353654098480234449948104235411615925382583281250119023549314211844514770152528313431629816760072652712779256593182979385499602121142246388146708842518881888087812525877628088241817693653010042696818501996803568328076434256134092327939489753162277188254738521227525878768762350427661065365503303990620895441197813594863830379759714354078526269966835756517333300191015795169546996325254857519128137848289
X = 1 << 400
Y = 2 * inthroot(Integer(2 * N), 2)
 
res = attack(N, e, 4, 2, X, Y)
print(res) # gives k and p + q, the rest is easy
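For completeness, a sketch of "the rest": given $k$ and $s = p + q$, the primes are the roots of $z^2 - sz + N$, after which $\phi$ and $d = (k\phi + 1)/e$ follow. The function name is hypothetical and the example uses toy numbers rather than the challenge values.

```python
import math


def recover_from_k_and_sum(N, e, k, s):
    """Rebuild p, q, phi and d from k and s = p + q (hypothetical helper)."""
    disc = s * s - 4 * N                  # (p - q)^2
    r = math.isqrt(disc)
    if r * r != disc:
        raise ValueError("s is not p + q for this N")
    p, q = (s + r) // 2, (s - r) // 2     # roots of z^2 - s*z + N
    phi = (p * p + p + 1) * (q * q + q + 1)
    if (k * phi + 1) % e:
        raise ValueError("k does not satisfy e*d = k*phi + 1")
    return p, q, phi, (k * phi + 1) // e
```

With $d$ in hand, applying the challenge's own `power(E, d, N)` should undo the encryption, since $E^d = M^{ed} = M^{1 + k\phi}$ and $\phi$ is the group order.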

 

 

Yet Another PRNG

#!/usr/bin/env python3
 
from Crypto.Util.number import *
import random
import os
from flag import flag
 
def urand(b):
    return int.from_bytes(os.urandom(b), byteorder='big')
 
class PRNG:
    def __init__(self):
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
 
        rnd = random.Random(b'rbtree')
 
        self.a1 = [rnd.getrandbits(20) for _ in range(3)]
        self.a2 = [rnd.getrandbits(20) for _ in range(3)]
        self.a3 = [rnd.getrandbits(20) for _ in range(3)]
 
        self.x = [urand(4) for _ in range(3)]
        self.y = [urand(4) for _ in range(3)]
        self.z = [urand(4) for _ in range(3)]
 
    def out(self):
        o = (2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]) % self.M
 
        self.x = self.x[1:] + [sum(x * y for x, y in zip(self.x, self.a1)) % self.m1]
        self.y = self.y[1:] + [sum(x * y for x, y in zip(self.y, self.a2)) % self.m2]
        self.z = self.z[1:] + [sum(x * y for x, y in zip(self.z, self.a3)) % self.m3]
 
        return o.to_bytes(8, byteorder='big')
 
if __name__ == "__main__":
    prng = PRNG()
 
    hint = b''
    for i in range(12):
        hint += prng.out()
    
    print(hint.hex())
 
    assert len(flag) % 8 == 0
    stream = b''
    for i in range(len(flag) // 8):
        stream += prng.out()
    
    out = bytes([x ^ y for x, y in zip(flag, stream)])
    print(out.hex())
    
 

 

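It turns out that simply writing the PRNG's relations (32-bit state words, the three linear recurrences, and the outputs modulo $2^{64} - 59$) as a system of linear inequalities and shoving it into my CVP repository works.

https://github.com/rkm0959/Inequality_Solving_with_CVP is very strong :O :O :O 

I've been procrastinating on updating and writing about that repository, very sorry about that....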
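To make "shoving the equations in" concrete, here is a pure-Python sketch of one way to model the PRNG as `lb <= w * mat <= ub`, with rows of `mat` as variables and columns as constraints (matching `num_var = mat.nrows()` in `solve` below). The unknowns are the 36 state words, 12 wrap-around multipliers for the outputs, and 27 wrap-around multipliers for the recurrences, giving a square 75 x 75 system. This layout and all helper names are illustrative assumptions, not necessarily the exact system the solve script builds; the test only checks that the true hidden state satisfies every constraint, while recovering it still needs the lattice step.

```python
import os
import random

M1, M2, M3 = 2**32 - 107, 2**32 - 5, 2**32 - 209
MOD = 2**64 - 59


def coeffs():
    """a1, a2, a3 from the public seed, drawn in the same order as the challenge."""
    rnd = random.Random(b"rbtree")
    return [[rnd.getrandbits(20) for _ in range(3)] for _ in range(3)]


def simulate(n=12):
    """Run the challenge PRNG locally; return the x, y, z sequences and n outputs."""
    seqs = [[int.from_bytes(os.urandom(4), "big") for _ in range(3)] for _ in range(3)]
    for s, a, m in zip(seqs, coeffs(), (M1, M2, M3)):
        while len(s) < n:
            s.append(sum(c * w for c, w in zip(a, s[-3:])) % m)
    xs, ys, zs = seqs
    outs = [(2 * M1 * x - M3 * y - M2 * z) % MOD for x, y, z in zip(xs, ys, zs)]
    return xs, ys, zs, outs


def build_system(outs, n=12):
    """Variables: x (n), y (n), z (n), output wraps k (n), recurrence wraps l (3(n-3))."""
    nvar = 4 * n + 3 * (n - 3)
    mat = [[0] * nvar for _ in range(nvar)]
    lb, ub, col = [], [], 0
    for v in range(3 * n):                      # every state word fits in 32 bits
        mat[v][col] = 1
        lb.append(0); ub.append(2**32 - 1); col += 1
    for i in range(n):                          # 2*M1*x - M3*y - M2*z - MOD*k == out
        mat[i][col], mat[n + i][col], mat[2 * n + i][col] = 2 * M1, -M3, -M2
        mat[3 * n + i][col] = -MOD
        lb.append(outs[i]); ub.append(outs[i]); col += 1
    lvar = 4 * n
    for t, (a, m) in enumerate(zip(coeffs(), (M1, M2, M3))):
        for i in range(n - 3):                  # a.s[i:i+3] - m*l - s[i+3] == 0
            for j in range(3):
                mat[t * n + i + j][col] = a[j]
            mat[t * n + i + 3][col] = -1
            mat[lvar][col] = -m
            lvar += 1
            lb.append(0); ub.append(0); col += 1
    return mat, lb, ub


def witness(xs, ys, zs, outs):
    """The true assignment of all variables, in build_system's order."""
    w = xs + ys + zs
    for x, y, z, o in zip(xs, ys, zs, outs):
        w.append((2 * M1 * x - M3 * y - M2 * z - o) // MOD)   # exact division
    for s, a, m in zip((xs, ys, zs), coeffs(), (M1, M2, M3)):
        for i in range(len(s) - 3):
            w.append(sum(a[j] * s[i + j] for j in range(3)) // m)
    return w


def satisfied(mat, lb, ub, w):
    vals = [sum(w[v] * mat[v][c] for v in range(len(w))) for c in range(len(lb))]
    return all(lo <= val <= hi for lo, val, hi in zip(lb, vals, ub))
```

In Sage, passing `matrix(ZZ, mat)`, `lb` and `ub` to the repository's `solve` would be the next step.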

 

# Sage script (uses BKZ, vector and Matrix from Sage)
import os
import random as rand
 
def urand(b):
    return int.from_bytes(os.urandom(b), byteorder='big')
 
class PRNGFinisher:
    def __init__(self, X, Y, Z):
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
 
        rnd = rand.Random(b'rbtree')
 
        self.a1 = [rnd.getrandbits(20) for _ in range(3)]
        self.a2 = [rnd.getrandbits(20) for _ in range(3)]
        self.a3 = [rnd.getrandbits(20) for _ in range(3)]
 
        self.x = X
        self.y = Y
        self.z = Z
 
    def out(self):
        o = (2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]) % self.M
 
        self.x = self.x[1:] + [sum(x * y for x, y in zip(self.x, self.a1)) % self.m1]
        self.y = self.y[1:] + [sum(x * y for x, y in zip(self.y, self.a2)) % self.m2]
        self.z = self.z[1:] + [sum(x * y for x, y in zip(self.z, self.a3)) % self.m3]
 
        return o.to_bytes(8, byteorder='big')
 
class PRNG:
    def __init__(self):
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
 
        rnd = rand.Random(b'rbtree')
 
        self.a1 = [rnd.getrandbits(20) for _ in range(3)]
        self.a2 = [rnd.getrandbits(20) for _ in range(3)]
        self.a3 = [rnd.getrandbits(20) for _ in range(3)]
 
        self.x = [urand(4) for _ in range(3)]
        self.y = [urand(4) for _ in range(3)]
        self.z = [urand(4) for _ in range(3)]
 
    def out(self):
        ret = b''
        xs = []
        ys = []
        zs = []
        for _ in range(12):
            xs.append(self.x[0])
            ys.append(self.y[0])
            zs.append(self.z[0])
            o = (2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]) % self.M
            self.x = self.x[1:] + [sum(x * y for x, y in zip(self.x, self.a1)) % self.m1]
            self.y = self.y[1:] + [sum(x * y for x, y in zip(self.y, self.a2)) % self.m2]
            self.z = self.z[1:] + [sum(x * y for x, y in zip(self.z, self.a3)) % self.m3]
            ret += o.to_bytes(8, byteorder='big')
        return ret, xs, ys, zs
 
 
# Directly taken from rbtree's LLL repository
# From https://oddcoder.com/LOL-34c3/, https://hackmd.io/@hakatashi/B1OM7HFVI
def Babai_CVP(mat, target):
    M = mat.BKZ(block_size = 35)
    G = M.gram_schmidt()[0]
    diff = target
    for i in reversed(range(G.nrows())):
        diff -=  M[i] * ((diff * G[i]) / (G[i] * G[i])).round()
    return target - diff
 
def solve(mat, lb, ub, weight = None):
    num_var  = mat.nrows()
    num_ineq = mat.ncols()
 
    max_element = 0 
    for i in range(num_var):
        for j in range(num_ineq):
            max_element = max(max_element, abs(mat[i, j]))
 
    if weight == None:
        weight = num_ineq * max_element
 
    # sanity checker
    if len(lb) != num_ineq:
        print("Fail: len(lb) != num_ineq")
        return
 
    if len(ub) != num_ineq:
        print("Fail: len(ub) != num_ineq")
        return
 
    for i in range(num_ineq):
        if lb[i] > ub[i]:
            print("Fail: lb[i] > ub[i] at index", i)
            return
 
    # heuristic for number of solutions
    DET = 0
 
    if num_var == num_ineq:
        DET = abs(mat.det())
        num_sol = 1
        for i in range(num_ineq):
            num_sol *= (ub[i] - lb[i])
        if DET == 0:
            print("Zero Determinant")
        else:
            num_sol //= DET
            # + 1 added in for the sake of not making it zero...
            print("Expected Number of Solutions : ", num_sol + 1)
 
    # scaling process begins
    max_diff = max([ub[i] - lb[i] for i in range(num_ineq)])
    applied_weights = []
 
    for i in range(num_ineq):
        ineq_weight = weight if lb[i] == ub[i] else max_diff // (ub[i] - lb[i])
        applied_weights.append(ineq_weight)
        for j in range(num_var):
            mat[j, i] *= ineq_weight
        lb[i] *= ineq_weight
        ub[i] *= ineq_weight
 
    # Solve CVP
    target = vector([(lb[i] + ub[i]) // 2 for i in range(num_ineq)])
    result = Babai_CVP(mat, target)
 
    for i in range(num_ineq):
        if (lb[i] <= result[i] <= ub[i]) == False:
            print("Fail : inequality does not hold after solving")
    
    # recover x
    fin = None
 
    if DET != 0:
        mat = mat.transpose()
        fin = mat.solve_right(result)
    
    ## recover your result
    return result, applied_weights, fin
 
def get_idx(name, v):
    if name == 'x':
        return v - 1
    if name == 'y':
        return v + 11
    if name == 'z':
        return v + 23
 
test = False
 
if test:
    prng = PRNG()
    # out() returns (ret, xs, ys, zs), so unpack exactly four values
    hint, XS, YS, ZS = prng.out()
    print("XS", XS)
    print("YS", YS)
    print("ZS", ZS)
 
    vec_sol = []
    for i in range(12):
        vec_sol.append(XS[i])
    for i in range(12):
        vec_sol.append(YS[i])
    for i in range(12):
        vec_sol.append(ZS[i])
else:
    prng = PRNG()
    hint = '67f19d3da8af1480f39ac04f7e9134b2dc4ad094475b696224389c9ef29b8a2aff8933bd3fefa6e0d03827ab2816ba0fd9c0e2d73e01aa6f184acd9c58122616f9621fb8313a62efb27fb3d3aa385b89435630d0704f0dceec00fef703d54fca'
    output = '153ed807c00d585860b843a03871b11f60baf11fe72d2619283ec5b4d931435ac378e21abe67c47f7923fcde101f4f0c65b5ee48950820f9b26e33acf57868d5f0cbc2377a39a81918f8c20f61c71047c8e82b1c965fa01b58ad0569ce7521c7'
    hint = bytes.fromhex(hint)
    output = bytes.fromhex(output)
 
print(len(hint))
M = Matrix(ZZ, 75, 75)
 
cnt = 0
tot_base = 36
 
lb = []
ub = []
 
# x
for i in range(9):
    M[get_idx('x', i + 4), cnt] = 1
    M[get_idx('x', i + 1), cnt] = -prng.a1[0]
    M[get_idx('x', i + 2), cnt] = -prng.a1[1]
    M[get_idx('x', i + 3), cnt] = -prng.a1[2]
    M[tot_base, cnt] = prng.m1
    cnt += 1
    tot_base += 1
    lb.append(0)
    ub.append(0)
 
# y 
for i in range(9):
    M[get_idx('y', i + 4), cnt] = 1
    M[get_idx('y', i + 1), cnt] = -prng.a2[0]
    M[get_idx('y', i + 2), cnt] = -prng.a2[1]
    M[get_idx('y', i + 3), cnt] = -prng.a2[2]
    M[tot_base, cnt] = prng.m2
    cnt += 1
    tot_base += 1
    lb.append(0)
    ub.append(0)
 
# z
for i in range(9):
    M[get_idx('z', i + 4), cnt] = 1
    M[get_idx('z', i + 1), cnt] = -prng.a3[0]
    M[get_idx('z', i + 2), cnt] = -prng.a3[1]
    M[get_idx('z', i + 3), cnt] = -prng.a3[2]
    M[tot_base, cnt] = prng.m3
    cnt += 1
    tot_base += 1
    lb.append(0)
    ub.append(0)
 
for i in range(12):
    M[get_idx('x', i + 1), cnt] = 1
    cnt += 1
    lb.append(0)
    ub.append(1 << 32)
 
for i in range(12):
    M[get_idx('y', i + 1), cnt] = 1
    cnt += 1
    lb.append(0)
    ub.append(1 << 32)
 
for i in range(12):
    M[get_idx('z', i + 1), cnt] = 1
    cnt += 1
    lb.append(0)
    ub.append(1 << 32)
 
for i in range(12):
    M[get_idx('x', i + 1), cnt] = (2 * prng.m1)
    M[get_idx('y', i + 1), cnt] = -prng.m3
    M[get_idx('z', i + 1), cnt] = -prng.m2
    M[tot_base, cnt] = prng.M
    cnt += 1
    tot_base += 1
    val = bytes_to_long(hint[8 * i : 8 * i + 8])
    lb.append(val)
    ub.append(val)
 
print(cnt)
print(tot_base)
 
result, applied_weights, fin = solve(M, lb, ub)
 
INIT_X = [int(fin[get_idx('x', i + 1)]) for i in range(3)]
INIT_Y = [int(fin[get_idx('y', i + 1)]) for i in range(3)]
INIT_Z = [int(fin[get_idx('z', i + 1)]) for i in range(3)]
 
print(fin)
print(INIT_X)
print(INIT_Y)
print(INIT_Z)
 
actual_prng = PRNGFinisher(INIT_X, INIT_Y, INIT_Z)
 
hint_check = b''
for i in range(12):
    hint_check += actual_prng.out()
 
sdaf = [hint_check[i] == hint[i] for i in range(96)]
print(sdaf)
 
if test == False:
    flag = b''
    for i in range(len(output) // 8):
        res = bytes_to_long(actual_prng.out())
        res = res ^ bytes_to_long(output[8 * i : 8 * i + 8])
        flag += long_to_bytes(res)
    print(flag)
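The 75 x 75 matrix above has 36 rows for the unknown 32-bit states (laid out by `get_idx`), 27 rows for the multiples of m1, m2, m3 in the recurrences and 12 rows for the multiples of M in the outputs. Its columns hold 27 recurrence equalities, 36 range constraints and 12 output equalities. As a sanity check that does not need Sage, a plain-Python sketch like the one below can confirm that candidate states satisfy exactly those relations. The helper name `check_constraints` is hypothetical. The multipliers and moduli are passed in as parameters because their definitions sit earlier in the PRNG class.

```python
def check_constraints(xs, ys, zs, a, m, M, words):
    """Check candidate states against every relation encoded in the 75 x 75 lattice.

    Hypothetical helper. a = (a1, a2, a3), m = (m1, m2, m3), words = the twelve
    64-bit output words; xs, ys, zs are the twelve states of each sequence.
    """
    (a1, a2, a3), (m1, m2, m3) = a, m
    for seq, coef, mod in ((xs, a1, m1), (ys, a2, m2), (zs, a3, m3)):
        # 9 recurrence columns per sequence: s[i+3] = sum(coef[j] * s[i+j]) mod `mod`
        for i in range(9):
            if (seq[i + 3] - sum(c * s for c, s in zip(coef, seq[i:i + 3]))) % mod:
                return False
        # 12 range columns per sequence, with the same inclusive bounds as lb/ub: 0 <= s <= 2^32
        if any(not 0 <= s <= 1 << 32 for s in seq):
            return False
    # 12 output columns: word = 2*m1*x - m3*y - m2*z mod M
    return all((2 * m1 * x - m3 * y - m2 * z - w) % M == 0
               for x, y, z, w in zip(xs, ys, zs, words))
```

Feeding it `INIT_X`-style candidates extended by the recurrence, together with the twelve hint words, should return `True` exactly when the CVP solution is consistent.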

 

 

Seed Me

import java.nio.file.Files;
import java.nio.file.Path;
import java.io.IOException;
import java.util.Random;
import java.util.Scanner;
 
class Main {
 
    private static void printFlag() {
        try {
            System.out.println(Files.readString(Path.of("flag.txt")));
        }
        catch(IOException e) {
            System.out.println("Flag file is missing, please contact admins");
        }
    }
 
    public static void main(String[] args) {
        int unlucky = 03777;
        int success = 0;
        int correct = 16;
 
        System.out.println(unlucky);
 
        System.out.println("Welcome to the 'Lucky Crystal Game'!");
        System.out.println("Please provide a lucky seed:");
        Scanner scr = new Scanner(System.in);
        long seed = scr.nextLong();
        Random rng = new Random(seed);
 
        for(int i=0; i<correct; i++) {
            /* Throw away the unlucky numbers */
            for(int j=0; j<unlucky; j++) {
                rng.nextFloat();
            }
 
            /* Do you feel lucky? */
            if (rng.nextFloat() >= (7.331f*.1337f)) {
                success++;
            }
        }
 
        if (success == correct) {
            printFlag();
        }
        else {
            System.out.println("Unlucky!");
        }
    }
}
 

 

Java's RNG is a truncated LCG, but here the truncation barely matters: `nextFloat()` returns the top 24 bits of the 48-bit LCG state divided by $2^{24}$, which is pretty much the state divided by $2^{48}$.
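To make "pretty much the state divided by $2^{48}$" concrete, here is a minimal Python model of the parts of `java.util.Random` that the challenge uses. It follows the documented algorithm: multiplier `0x5DEECE66D`, addend `0xB`, a 48-bit state, and an initial XOR scramble of the user seed. The class and method names are my own, not Java's.

```python
MULT, ADDEND, MASK = 0x5DEECE66D, 0xB, (1 << 48) - 1

class JavaRandom:
    """Sketch of java.util.Random's constructor, next(bits) and nextFloat()."""

    def __init__(self, seed):
        # new Random(seed) stores (seed ^ 0x5DEECE66D) mod 2^48
        self.state = (seed ^ MULT) & MASK

    def next_bits(self, bits):
        # one LCG step modulo 2^48, then keep the top `bits` bits
        self.state = (self.state * MULT + ADDEND) & MASK
        return self.state >> (48 - bits)

    def next_float(self):
        # nextFloat() = next(24) / 2^24: the state / 2^48, truncated to 24 bits
        return self.next_bits(24) / (1 << 24)
```

Each returned float falls short of `state / 2**48` by less than $2^{-24}$. That is why the solver below can treat every check as a bound on the full 48-bit state.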

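This is ultimately a hidden-number-style problem: each of the 16 checks is a linear inequality on the unknown 48-bit state modulo $2^{48}$. So lattices are the natural tool, and the CVP repository should work.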
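The linearity comes from jumping ahead: after $k$ steps the state is $s_k = m_k s_0 + b_k \bmod 2^{48}$, which is what the `curm`/`curb` tables in the solver compute. Since `03777` is octal (2047 discarded floats, then one tested), round $i$ tests step $k = 2048(i+1)$. The round passes roughly when $\lfloor 0.9803 \cdot 2^{48} \rfloor - b_k \le m_k s_0 + t \cdot 2^{48} \le 2^{48} - 1 - b_k$ for some integer $t$. The sketch below computes these coefficients and bounds. `jump`, `CHECKED_STEPS` and `round_bounds` are hypothetical names, and `lo=0.9803` mirrors the solver's default lower bound.

```python
MULT, ADDEND, MASK = 0x5DEECE66D, 0xB, (1 << 48) - 1
UNLUCKY = 0o3777   # Java's 03777 is octal, i.e. 2047 discarded floats per round

def jump(k):
    """Return (m_k, b_k) such that state_k = (m_k * state_0 + b_k) mod 2^48."""
    m, b = 1, 0
    for _ in range(k):
        # compose one more LCG step s -> MULT * s + ADDEND
        m, b = (MULT * m) & MASK, (MULT * b + ADDEND) & MASK
    return m, b

# Round i (0-based) tests the float produced at step 2048 * (i + 1).
CHECKED_STEPS = [(UNLUCKY + 1) * (i + 1) for i in range(16)]

def round_bounds(i, lo=0.9803):
    """Lattice column for round i: bounds on m_k * s0 + t * 2^48 (b_k already subtracted)."""
    m, b = jump(CHECKED_STEPS[i])
    return m, int(lo * (1 << 48)) - b, (1 << 48) - 1 - b
```

Each `(m, lb, ub)` triple corresponds to one column of `M` in the solver (`M[0, i] = m`, `M[i + 1, i] = 1 << 48`).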

However, naively plugging in the lower and upper bound vectors gives a solution where some of the 16 values are slightly off.

To work around this, we adjust the lower and upper bounds by hand to "persuade" the CVP algorithm toward a result we like. For example, if one value comes out as 0.97, smaller than we need, we make that lower bound a bit larger. If one value comes out as 0.01, we overshot past 1 and wrapped around modulo $2^{48}$, so we reduce that upper bound so that the value lands between 0.98 and 1.
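Tweaking bounds by hand is easier when you can quickly see which of the 16 rounds a candidate seed fails. Below is a plain-Python sketch that replays the challenge loop and compares each tested value against the float threshold; `f32`, `replay` and `report` are hypothetical helpers. The threshold is modelled by rounding the literals and the product to single precision, which approximates Java's `7.331f*.1337f`. It comes out near 0.98015, so the solver's lower bound of $0.9803 \cdot 2^{48}$ leaves a small margin.

```python
import math
import struct

MULT, ADDEND, MASK = 0x5DEECE66D, 0xB, (1 << 48) - 1

def f32(x):
    """Round a Python float to the nearest IEEE-754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# Approximates Java's 7.331f * .1337f (literals and product rounded to float).
THRESHOLD = f32(f32(7.331) * f32(0.1337))
# nextFloat() is next(24) / 2^24, so a round passes iff next(24) >= MIN_TOP24.
MIN_TOP24 = math.ceil(THRESHOLD * (1 << 24))

def replay(seed, rounds=16, unlucky=0o3777):
    """Replay the challenge loop and return the tested next(24) value of each round."""
    state = (seed ^ MULT) & MASK          # java.util.Random's initial scramble
    tops = []
    for _ in range(rounds):
        for _ in range(unlucky + 1):      # 2047 discarded floats plus the tested one
            state = (state * MULT + ADDEND) & MASK
        tops.append(state >> 24)
    return tops

def report(seed):
    """Print each round's float and whether it clears the threshold."""
    for i, top in enumerate(replay(seed)):
        print(f"round {i:2d}: {top / (1 << 24):.6f} {'ok' if top >= MIN_TOP24 else 'FAIL'}")
```

Calling `report(init_seed)` on the solver's output shows which rounds need their `lb` raised or their `ub` lowered.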

 

# Directly taken from rbtree's LLL repository
# From https://oddcoder.com/LOL-34c3/, https://hackmd.io/@hakatashi/B1OM7HFVI
def Babai_CVP(mat, target):
    M = IntegerLattice(mat, lll_reduce=True).reduced_basis
    G = M.gram_schmidt()[0]
    diff = target
    for i in reversed(range(G.nrows())):
        diff -=  M[i] * ((diff * G[i]) / (G[i] * G[i])).round()
    return target - diff
 
 
def solve(mat, lb, ub, weight = None):
    num_var  = mat.nrows()
    num_ineq = mat.ncols()
 
    max_element = 0 
    for i in range(num_var):
        for j in range(num_ineq):
            max_element = max(max_element, abs(mat[i, j]))
 
    if weight == None:
        weight = num_ineq * max_element
 
    # sanity checker
    if len(lb) != num_ineq:
        print("Fail: len(lb) != num_ineq")
        return
 
    if len(ub) != num_ineq:
        print("Fail: len(ub) != num_ineq")
        return
 
    for i in range(num_ineq):
        if lb[i] > ub[i]:
            print("Fail: lb[i] > ub[i] at index", i)
            return
 
    # heuristic for number of solutions
    DET = 0
 
    if num_var == num_ineq:
        DET = abs(mat.det())
        num_sol = 1
        for i in range(num_ineq):
            num_sol *= (ub[i] - lb[i])
        if DET == 0:
            print("Zero Determinant")
        else:
            num_sol //= DET
            # + 1 added in for the sake of not making it zero...
            print("Expected Number of Solutions : ", num_sol + 1)
 
    # scaling process begins
    max_diff = max([ub[i] - lb[i] for i in range(num_ineq)])
    applied_weights = []
 
    for i in range(num_ineq):
        ineq_weight = weight if lb[i] == ub[i] else max_diff // (ub[i] - lb[i])
        applied_weights.append(ineq_weight)
        for j in range(num_var):
            mat[j, i] *= ineq_weight
        lb[i] *= ineq_weight
        ub[i] *= ineq_weight
 
    # Solve CVP
    target = vector([(lb[i] + ub[i]) // 2 for i in range(num_ineq)])
    result = Babai_CVP(mat, target)
 
    for i in range(num_ineq):
        if (lb[i] <= result[i] <= ub[i]) == False:
            print("Fail : inequality does not hold after solving")
            break
    
    # recover x
    fin = None
 
    if DET != 0:
        mat = mat.transpose()
        fin = mat.solve_right(result)
    
    ## recover your result
    return result, applied_weights, fin
 
# conn = remote('seedme.chal.perfect.blue', 1337)
# conn.interactive()
 
def getv(seed):
    # one java.util.Random step: returns the new 48-bit state and nextFloat()
    seed = (seed * 0x5DEECE66D + 0xB) & ((1 << 48) - 1)
    return seed, (seed >> 24) / (1 << 24)
 
curm = [1]
curb = [0]
 
M = Matrix(ZZ, 17, 17)
lb = [0] * 17
ub = [0] * 17
 
for i in range(16 * 2048):
    curm.append((0x5DEECE66D * curm[i]) % (1 << 48))
    curb.append((0x5DEECE66D * curb[i] + 0xB) % (1 << 48))
 
for i in range(0, 16):
    m, b = curm[2048 * i + 2048], curb[2048 * i + 2048]
    M[0, i] = m
    M[i + 1, i] = 1 << 48
    lb[i] = int(0.9803 * (1 << 48)) - b 
    ub[i] = int((1 << 48)) - 1 - b
 
# post-fix manually
lb[0] = int(0.985 * (1 << 48)) - curb[2048]
ub[15] = int(0.995 * (1 << 48)) - curb[2048 * 16]
 
M[0, 16] = 1
lb[16] = 0
ub[16] = 1 << 48
 
result, applied_weights, fin = solve(M, lb, ub)
 
res = (int(fin[0]) + (1 << 48)) % (1 << 48)
 
init_seed = 0x5DEECE66D ^ res 
 
print(init_seed)
 
seeds = init_seed
seeds = (seeds ^ 0x5DEECE66D) & ((1 << 48) - 1)
 
curm = [1]
curb = [0]
 
for i in range(16 * 2048):
    curm.append((0x5DEECE66D * curm[i]) % (1 << 48))
    curb.append((0x5DEECE66D * curb[i] + 0xB) % (1 << 48))
 
for i in range(0, 16):
    m, b = curm[2048 * i + 2048], curb[2048 * i + 2048]
    res = (seeds * m + b) % (1 << 48)
    print(res / (1 << 48) >= 0.7331 * 1.337)

 

 

Other posts in the 'Math > Cryptography and CTF' category

SECCON CTF 2021 Writeups  (0) 2021.12.14
N1CTF 2021 Writeups  (1) 2021.11.22
PBCTF 2021 Writeups  (0) 2021.10.13
TSGCTF 2021 Writeups  (0) 2021.10.03
DUCTF 2021 Writeups  (0) 2021.09.26
ACSC Crypto Writeups  (0) 2021.09.26