I made one crypto challenge each for the CODEGATE CTF 2022 quals and finals. In this article, I’m gonna explain how the interpolation attack works and show how to solve my challenge “Hidden Command Service” from the finals.

Interpolation Attack

Once upon a time, there was a dangerous attack technique called ‘differential cryptanalysis’… Then, two crypto-magicians designed a cipher that is provably secure against differential cryptanalysis, called the KN-Cipher.

However, it soon turned out that the structure of the S-boxes used in such ciphers is vulnerable to a new attack called the ‘interpolation attack’, and a variant of the KN-Cipher called PURE was broken.

The S-box used in PURE is basically $S(x) = x^3$ in $\text{GF}(2^{32})$, usually applied as $(x + k)^3$ for an input $x$ and a subkey $k$. The key idea of the interpolation attack comes from here: when we use a service, the key of the cipher is fixed to a value we don’t know. Whatever that value is, it stays fixed, so we can treat the subkeys as unknown constants. If the cipher consists only of operations over a finite field, like $x^3$ over $\text{GF}(2^{32})$, then the output is a polynomial in the input whose coefficients are unknown.
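To make the "unknown coefficients" concrete, here is the keyed S-box expanded by hand. This is just the binomial theorem in characteristic 2, where $3 \equiv 1 \pmod 2$ and addition is XOR:

$$(x + k)^3 = x^3 + 3kx^2 + 3k^2x + k^3 = x^3 + kx^2 + k^2x + k^3$$

So a single keyed S-box is a cubic in $x$ whose coefficients are $1, k, k^2, k^3$: fixed, but unknown to us.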

For example, consider a 4-round cipher like this (all variables and operations are over a finite field):

$x_0 = \text{input}, k_i = \text{subkey}, x_i = (x_{i-1} + k_i)^3, x_4 = \text{output}$

We don’t know $k_i$, but we know that $x_4$ can be written as a degree-81 polynomial in $x_0$, since each round triples the degree and $3^4 = 81$. A polynomial of degree 81 has 82 coefficients, so 82 plaintext-ciphertext pairs are enough to recover it. The name of the attack comes from here: Lagrange interpolation.
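Here is a minimal sketch of that 4-round example in plain Python. To keep it small I assume the toy field $\text{GF}(2^8)$ with the modulus $x^8+x^4+x^3+x+1$ (not the field of PURE or of the challenge), and the subkey values and helper names are made up for illustration. Instead of recovering the 82 coefficients explicitly, it evaluates the Lagrange form directly, which is enough to encrypt without the key.

```python
def gf_mul(a, b):
    """Multiply in GF(2^8) modulo x^8+x^4+x^3+x+1 (a toy field, not the challenge's)."""
    res = 0
    while b:
        if b & 1:
            res ^= a
        b >>= 1
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
    return res


def gf_pow(a, e):
    """Square-and-multiply exponentiation in GF(2^8)."""
    res = 1
    while e:
        if e & 1:
            res = gf_mul(res, a)
        a = gf_mul(a, a)
        e >>= 1
    return res


def toy_encrypt(x, subkeys):
    """x_i = (x_{i-1} + k_i)^3, where + is XOR in characteristic 2."""
    for k in subkeys:
        x = gf_pow(x ^ k, 3)
    return x


def lagrange_eval(points, x):
    """Evaluate at x the unique polynomial of degree < len(points) through points."""
    total = 0
    for j, (xj, yj) in enumerate(points):
        num, den = 1, 1
        for m, (xm, _) in enumerate(points):
            if m != j:
                num = gf_mul(num, x ^ xm)    # (x - x_m), minus is XOR here
                den = gf_mul(den, xj ^ xm)   # (x_j - x_m), never zero
        # Division is multiplication by the inverse, and a^-1 = a^254 in GF(2^8).
        total ^= gf_mul(yj, gf_mul(num, gf_pow(den, 254)))
    return total


secret_subkeys = [0x3A, 0xC4, 0x59, 0xE7]  # hypothetical, unknown to the attacker
known = [(x, toy_encrypt(x, secret_subkeys)) for x in range(82)]  # 82 known pairs

# Predict ciphertexts for plaintexts the attacker never queried.
unseen = range(82, 256, 6)
predicted = [lagrange_eval(known, x) for x in unseen]
actual = [toy_encrypt(x, secret_subkeys) for x in unseen]
print(predicted == actual)
```

The comparison holds because the degree (81) is below the number of known points (82), so the interpolating polynomial is the cipher itself.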

Hidden Command Service

Here’s the code of Hidden Command Service.

#!/usr/bin/python3
from os import urandom as random
from hashlib import sha256
from time import time
from subprocess import check_output, DEVNULL


def gf_mul(a, b):
    res = 0
    for i in range(16):
        res <<= 1
        if res & 0x10000:
            res ^= 0x15A55
        if b & (1 << (15 - i)):
            res ^= a
    return res


def gf_pow(v, i):
    if i == 1:
        return v
    return gf_mul(gf_pow(v, i - 1), v)


class BlockCipher:
    # There are various ways to block differential cryptanalysis.
    # Kaisa Nyberg and Lars Knudsen proved this s-box structure is
    # safe from differential cryptanalysis attacks.
    SBOX = [gf_pow(v, 3) ^ 3 for v in range(2**16)]
    BLOCK_SIZE = 8
    KEY_SIZE = 32
    NUM_ROUND = 4

    def __init__(self, key):
        assert type(key) == bytes
        assert len(key) == self.KEY_SIZE
        self.rkey = []
        for i in range(0, self.KEY_SIZE, 2):
            self.rkey.append(int.from_bytes(key[i : i + 2], "little"))

    def encrypt_block(self, block):
        block = [
            int.from_bytes(block[i : i + 2], "little")
            for i in range(0, self.BLOCK_SIZE, 2)
        ]

        for rnd in range(self.NUM_ROUND):
            for idx in range(self.BLOCK_SIZE // 2):
                block[idx] ^= self.SBOX[
                    block[(idx + 1) % (self.BLOCK_SIZE // 2)]
                    ^ self.rkey[(self.BLOCK_SIZE // 2) * rnd + idx]
                ]

        return b"".join(v.to_bytes(2, "little") for v in block)

    def decrypt_block(self, block):
        block = [
            int.from_bytes(block[i : i + 2], "little")
            for i in range(0, self.BLOCK_SIZE, 2)
        ]
        for rnd in reversed(range(self.NUM_ROUND)):
            for idx in reversed(range(self.BLOCK_SIZE // 2)):
                block[idx] ^= self.SBOX[
                    block[(idx + 1) % (self.BLOCK_SIZE // 2)]
                    ^ self.rkey[(self.BLOCK_SIZE // 2) * rnd + idx]
                ]
        return b"".join(v.to_bytes(2, "little") for v in block)

    @classmethod
    def _pad(cls, b):
        v = cls.BLOCK_SIZE - len(b) % cls.BLOCK_SIZE
        return b + bytes([v] * v)

    @classmethod
    def _unpad(cls, b):
        if not b:
            # TODO: Define an exception and raise here
            return b""
        if not (1 <= b[-1] <= cls.BLOCK_SIZE):
            # TODO: Define an exception and raise here
            return b""
        return b[: -b[-1]]

    @staticmethod
    def _xor(a, b):
        return bytes(x ^ y for x, y in zip(a, b))

    def encrypt(self, pt):
        mac = sha256(pt).digest()[: self.BLOCK_SIZE]
        pt = self._pad(pt) + mac
        ct = random(self.BLOCK_SIZE)
        for idx in range(0, len(pt), self.BLOCK_SIZE):
            block = pt[idx : idx + self.BLOCK_SIZE]
            ct += self.encrypt_block(self._xor(block, ct[-self.BLOCK_SIZE :]))
        return ct

    def decrypt(self, ct):
        pt = b""
        for idx in range(self.BLOCK_SIZE, len(ct), self.BLOCK_SIZE):
            iv, block = ct[idx - self.BLOCK_SIZE : idx], ct[idx : idx + self.BLOCK_SIZE]
            pt += self._xor(iv, self.decrypt_block(block))
        if len(pt) < 2 * self.BLOCK_SIZE:
            return b""
        padded, mac = pt[: -self.BLOCK_SIZE], pt[-self.BLOCK_SIZE :]
        pt = self._unpad(padded)
        if (not pt) or sha256(pt).digest()[: self.BLOCK_SIZE] != mac:
            return b""
        return pt


def run_command(cmd):
    print("Result:")
    try:
        print(check_output(cmd, stdin=DEVNULL, shell=True).decode())
    except:
        pass


def main():
    seed = random(16)
    print("Welcome to the hidden command service!")
    print(f"The seed of this time is: {seed.hex()}")

    with open("./password", "rb") as f:
        password = f.read()

    key = sha256(seed + password).digest()
    cipher = BlockCipher(key)

    for _ in range(500):
        inp = input("> ")
        if inp == "emergency":
            print("[EMERGENCY MODE]")
            print("0. Get the server time (Testing purpose!)")
            print("1. Get the target info")

            inp = input("> ")
            if inp == "0":
                cmd = "date"
            elif inp == "1":
                cmd = "cat target_info"
            else:
                print("Wrong input :(")
                continue

            run_command(cmd)
            enc_cmd = cipher.encrypt(cmd.encode())
            print(f"Don't forget the encrypted command: {enc_cmd.hex()}")
            continue
        elif inp == "exit":
            print("Bye!")
            return

        try:
            enc_cmd = bytes.fromhex(inp.strip())
        except ValueError:
            print("Wrong input :(")
            continue

        if len(enc_cmd) % cipher.BLOCK_SIZE:
            print("Wrong input size :(")
            continue

        cmd = cipher.decrypt(enc_cmd)
        try:
            cmd = cmd.decode()
        except:
            cmd = ""

        run_command(cmd)

    print("There's a trial limit on each seed.")
    print("Please connect again. Bye!")


if __name__ == "__main__":
    main()

I made a mistake in _unpad(), so it’s possible to solve the challenge without attacking the cipher itself. Try that out too ;)
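One small thing worth checking before touching Sage: the constant 0x15A55 in gf_mul() encodes the field modulus, one bit per term. Here is a quick sketch in plain Python (the variable names are mine, not from the challenge) that decodes it, so you can compare the result with the modulus passed to GF(2^16) in the Sage snippets.

```python
# Reduction constant used by gf_mul() in the service code.
REDUCTION = 0x15A55

# Bit i set means the term x^i is present in the modulus.
exponents = [
    i for i in range(REDUCTION.bit_length() - 1, -1, -1) if (REDUCTION >> i) & 1
]

terms = ["1" if e == 0 else f"x^{e}" for e in exponents]
modulus = " + ".join(terms)
print(modulus)
```

If the two moduli did not match, every field element built with F.fetch_int() would multiply differently from the server's SBOX table, and the recovered polynomials would be garbage.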

Check output polynomials

The cipher takes four input words over $\text{GF}(2^{16})$ and repeatedly applies $(x + k)^3 + 3$ with unknown subkeys $k$. Hmm, how many monomials will there be in the output values?

We can easily check with SageMath:

F.<x> = GF(2^16, modulus=x^16+x^14+x^12+x^11+x^9+x^6+x^4+x^2+1)

P.<i0, i1, i2, i3> = PolynomialRing(F)

f = [i0, i1, i2, i3]
for i in range(4):
    for j in range(4):
        # FIXME: This may not work sometimes
        f[j] += (f[(j + 1) % 4] + F.fetch_int(randint(1, 2^16 - 1))) ^ 3 + F.fetch_int(3)
print([len(v.monomials()) for v in f])

If you run this code, you will soon realize that it takes too much time, and the result has too many monomials. How can we reduce the number of monomials?

Meet in the Middle

Think of it this way: we have 8 variables, $i_0, i_1, i_2, i_3$ (inputs) and $o_0, o_1, o_2, o_3$ (outputs). Define $f$ as 3 rounds forward from the inputs, and $g$ as 1 round backward from the outputs. Both describe the same intermediate state, so $f(i_0, i_1, i_2, i_3) = g(o_0, o_1, o_2, o_3)$.

P.<i0, i1, i2, i3, o0, o1, o2, o3> = PolynomialRing(F)

f = [i0, i1, i2, i3]
for i in range(3):
    for j in range(4):
        # FIXME: This may not work sometimes
        f[j] += (f[(j + 1) % 4] + F.fetch_int(randint(1, 2^16 - 1))) ^ 3 + F.fetch_int(3)
print([len(v.monomials()) for v in f])

g = [o0, o1, o2, o3]
for j in reversed(range(4)):
    # FIXME: This may not work sometimes
    g[j] += (g[(j + 1) % 4] + F.fetch_int(randint(1, 2^16 - 1))) ^ 3 + F.fetch_int(3)
print([len(v.monomials()) for v in g])

The result is [106, 726, 1254, 1306] and [543, 80, 18, 5]. It must be doable now!

Let’s solve!

Calculate monomials

First, let’s calculate possible monomials.

from pwn import *
from tqdm import tqdm
from hashlib import sha256
import sys
from multiprocessing import Pool

IP = sys.argv[1]
PORT = int(sys.argv[2])

F.<x> = GF(2^16, modulus=x^16+x^14+x^12+x^11+x^9+x^6+x^4+x^2+1)

P.<i0, i1, i2, i3, o0, o1, o2, o3> = PolynomialRing(F)

f = [i0, i1, i2, i3]
for i in range(3):
    for j in range(4):
        # FIXME: This may not work sometimes
        f[j] += (f[(j + 1) % 4] + F.fetch_int(randint(1, 2^16 - 1))) ^ 3 + F.fetch_int(3)
f = [v.monomials()[:-1] for v in f]

g = [o0, o1, o2, o3]
for j in reversed(range(4)):
    # FIXME: This may not work sometimes
    g[j] += (g[(j + 1) % 4] + F.fetch_int(randint(1, 2^16 - 1))) ^ 3 + F.fetch_int(3)
g = [v.monomials()[:-1] for v in g]

num_monomials = 0
for i in range(4):
    ln = len(f[i]) + len(g[i])
    num_monomials = max(num_monomials, ln)
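As a rough budget check, here is what num_monomials works out to if the monomial counts match the ones reported above. That is an assumption: the counts come from one run with random stand-in subkeys. The sketch also computes how many "emergency" queries the next step will send, assuming each query yields three block pairs and costs one of the 500 loop iterations of the service.

```python
# Monomial counts reported in this article (they include the constant term).
f_counts = [106, 726, 1254, 1306]
g_counts = [543, 80, 18, 5]

# The script drops the constant from each side with monomials()[:-1].
per_index = [(fc - 1) + (gc - 1) for fc, gc in zip(f_counts, g_counts)]
num_monomials = max(per_index)

# Each "emergency" query returns IV + 3 ciphertext blocks, i.e. 3 usable pairs.
queries = num_monomials // 3 + 2
pairs = 3 * queries

QUERY_LIMIT = 500  # from `for _ in range(500)` in the service
print(per_index, num_monomials, queries, pairs, queries <= QUERY_LIMIT)
```

Each linear system has one unknown per remaining monomial (the leading one is traded for the shared constant), so the number of pairs only needs to be slightly above num_monomials.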

Get pt-ct pairs

Get some plaintext-ciphertext pairs:

r = remote(IP, PORT)

cmd = b"cat target_info\x01" + sha256(b"cat target_info").digest()[:8]

def xor(x, y):
    return bytes([a ^^ b for a, b in zip(x, y)])

pairs = []
print("Getting input/outputs")
for i in tqdm(range(num_monomials // 3 + 2)):
    r.sendlineafter(b'> ', b'emergency')
    r.sendlineafter(b'> ', b'1')

    r.recvuntil(b'command: ')
    res = bytes.fromhex(r.recvline().strip().decode())

    pairs.append( (xor(cmd[:8], res[:8]), res[8:16]) )
    pairs.append( (xor(cmd[8:16], res[8:16]), res[16:24]) )
    pairs.append( (xor(cmd[16:], res[16:24]), res[24:]) )

args = []
for x, y in pairs:
    arr = []
    for i in range(0, 8, 2):
        arr.append(F.fetch_int(int.from_bytes(x[i:i+2], 'little')))
    for i in range(0, 8, 2):
        arr.append(F.fetch_int(int.from_bytes(y[i:i+2], 'little')))
    args.append(arr)

print("Getting input/outputs done")
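Why do these tuples count as pairs for the raw block cipher? In CBC mode, the cipher input is the plaintext block XORed with the previous ciphertext block, and we know both. The sketch below replays that bookkeeping offline in plain Python. Note that toy_block() is a hypothetical stand-in for encrypt_block() (any fixed secret 8-byte map is enough to check the indexing), and cbc_encrypt() only mirrors the layout of the service's encrypt().

```python
from hashlib import sha256
from os import urandom

BLOCK = 8


def toy_block(block):
    """Stand-in for encrypt_block(): a fixed secret 8-byte -> 8-byte map."""
    return sha256(b"hypothetical-secret-key" + block).digest()[:BLOCK]


def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


def cbc_encrypt(pt):
    """Same layout as the service: pad, append truncated SHA-256, chain from a random IV."""
    pad = BLOCK - len(pt) % BLOCK
    padded = pt + bytes([pad] * pad) + sha256(pt).digest()[:BLOCK]
    ct = urandom(BLOCK)
    for idx in range(0, len(padded), BLOCK):
        ct += toy_block(xor(padded[idx : idx + BLOCK], ct[-BLOCK:]))
    return padded, ct


padded, ct = cbc_encrypt(b"cat target_info")

# Plaintext block n XOR ciphertext block n-1 is the cipher input;
# ciphertext block n is the matching output.
pairs = [
    (xor(padded[i : i + BLOCK], ct[i : i + BLOCK]), ct[i + BLOCK : i + 2 * BLOCK])
    for i in range(0, len(padded), BLOCK)
]

print(len(pairs), all(toy_block(x) == y for x, y in pairs))
```

Because the IV is random on every call, repeating the same command still gives fresh cipher inputs each time, which is what makes the "emergency" menu a usable source of pairs.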

Recover $f$ and $g$ except the constant coefficient

Now we can calculate both $f$ and $g$! But wait, it feels like we’re missing something. We only have $f(i) - g(o) = 0$. Since $f$ and $g$ share no monomials, this mostly does not matter… except for the constant coefficients: only the sum of the two constants shows up in the equation, so we can’t separate them yet. Uh, let’s calculate the other coefficients first, using the fact that the coefficient of the highest-degree monomial of $f$ is $1$.

f_eqs = []
g_eqs = []

for idx in range(4):
    monomials = f[idx] + g[idx]
    mat = []
    vec = []
    print(f"Retrieving matrix for f{idx} & g{idx}")

    for arg in tqdm(args):
        # The max-degree monomial always has the coefficient 1
        vec.append(monomials[0](*arg))

        def for_multiproc(x):
            return x(*arg)

        with Pool(4) as pool:
            row = pool.map(for_multiproc, monomials[1:])
        row.append(1) # constant coefficient
        mat.append(row)

    print(f"Solve linear system for f{idx} & g{idx}")

    mat = Matrix(mat)
    vec = vector(vec)
    # This takes too much time
    # assert mat.right_kernel().dimension() == 0
    res = mat.solve_right(vec)
    print("Solved the linear system")

    f_eq = f[idx][0]
    for i in range(1, len(f[idx])):
        f_eq += res[i - 1] * f[idx][i]
    f_eq += res[-1]
    f_eqs.append(f_eq)

    g_eq = 0
    for i in range(len(g[idx])):
        g_eq += res[i + len(f[idx]) - 1] * g[idx][i]
    g_eqs.append(g_eq)

    print(f"Got f{idx}(x) and g{idx}(x)")
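The same trick on a toy scale may make the matrix layout easier to follow. This is a sketch under my own assumptions, not the solver above: a single keyed S-box $y = (x + k)^3 + 3$ over the toy field $\text{GF}(2^8)$, a made-up subkey, and a small hand-written Gauss-Jordan solver in place of Sage's solve_right(). The known leading monomial $x^3$ goes to the right-hand side, and the unknowns are the coefficients of $x^2$, $x$, $y$ and one merged constant.

```python
def gf_mul(a, b):
    """Multiply in GF(2^8) modulo x^8+x^4+x^3+x+1 (toy field, not the challenge's)."""
    res = 0
    while b:
        if b & 1:
            res ^= a
        b >>= 1
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
    return res


def gf_pow(a, e):
    res = 1
    for _ in range(e):
        res = gf_mul(res, a)
    return res


def solve(mat, vec):
    """Gauss-Jordan elimination over GF(2^8); assumes the solution is unique."""
    n = len(mat[0])
    rows = [row[:] + [v] for row, v in zip(mat, vec)]
    for col in range(n):
        piv = next(r for r in range(col, len(rows)) if rows[r][col])
        rows[col], rows[piv] = rows[piv], rows[col]
        inv = gf_pow(rows[col][col], 254)  # a^-1 = a^254 in GF(2^8)
        rows[col] = [gf_mul(inv, v) for v in rows[col]]
        for r in range(len(rows)):
            if r != col and rows[r][col]:
                factor = rows[r][col]
                rows[r] = [v ^ gf_mul(factor, w) for v, w in zip(rows[r], rows[col])]
    return [rows[i][n] for i in range(n)]


k = 0x5C  # hypothetical secret subkey
samples = [(x, gf_pow(x ^ k, 3) ^ 3) for x in range(1, 7)]  # y = (x + k)^3 + 3

# One row per sample: values of x^2, x, y and 1; x^3 is the right-hand side.
mat = [[gf_pow(x, 2), x, y, 1] for x, y in samples]
vec = [gf_pow(x, 3) for x, _ in samples]
c_x2, c_x, c_y, c_const = solve(mat, vec)

print(c_x2 == k, c_x == gf_pow(k, 2), c_y == 1, c_const == gf_pow(k, 3) ^ 3)
```

Note that there is only one constant column. If both sides had their own constant, the two columns would be identical and the system would lose rank, which is exactly why the constants have to be split in a separate step.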

Recover full $f$ and $g$

Now, we have to recover the constant coefficient. But how?

We will recover the last four subkeys in this case. As we know the non-constant coefficients of $g$, we can read the subkeys off them:

$g_3 = ((o_0 + k_{-1})^3 + 3 + o_3)$, so the coefficient of $o_0^2$ must be $k_{-1}$.

$g_2 = ((g_3 + k_{-2})^3 + 3 + o_2)$, so the coefficient of $o_3^2$ must be $k_{-2}$ plus the constant term of $g_3$, which we already know from $k_{-1}$. …
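To see exactly which coefficient to read, and why the script below keeps a running const, here is the second step written out. The names $c_3$, $c_2$ and $B$ are my own shorthand. Let $c_3 = k_{-1}^3 + 3$ be the constant term of $g_3$, and collect everything in $g_3 + k_{-2}$ except $o_3$ into $B$:

$$g_3 = o_3 + o_0^3 + k_{-1}o_0^2 + k_{-1}^2 o_0 + c_3, \qquad B = o_0^3 + k_{-1}o_0^2 + k_{-1}^2 o_0 + (c_3 + k_{-2})$$

$$(g_3 + k_{-2})^3 = (o_3 + B)^3 = o_3^3 + B\,o_3^2 + B^2 o_3 + B^3$$

$B$ depends only on $o_0$, so the pure monomial $o_3^2$ appears only in $B\,o_3^2$, with coefficient equal to the constant term of $B$:

$$[o_3^2]\,g_2 = c_3 + k_{-2} \quad\Longrightarrow\quad k_{-2} = [o_3^2]\,g_2 + c_3, \qquad c_2 = (c_3 + k_{-2})^3 + 3$$

The same pattern repeats for $g_1$ and $g_0$, each time adding the constant term of the previous $g$.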

print("Getting keys from g*(x)")

keys = []
os = [o0, o1, o2, o3]
const = 0
for i in reversed(range(4)):
    k = g_eqs[i].monomial_coefficient(os[(i + 1) % 4]^2) + const
    keys.append(k)
    const = (const + k) ^ 3 + F.fetch_int(3)
keys = keys[::-1]

print("Got keys, getting constant coefficients of f*(x)")

g = [o0, o1, o2, o3]
for i in reversed(range(4)):
    g[i] += (g[(i + 1) % 4] + keys[i]) ^ 3 + F.fetch_int(3)
    res = (g_eqs[i] + g[i])
    f_eqs[i] += res

print("Got the constant coefficients")

Re-build the cipher

Finally, we can re-build the cipher without knowing the key.

def encrypt_block(block):
    block = [F.fetch_int(int.from_bytes(block[i : i + 2], "little")) for i in range(0, 8, 2)]
    block += [0] * 4

    block = [f_eqs[i](*block) for i in range(4)]
    for i in range(4):
        block[i] += (block[(i + 1) % 4] + keys[i]) ^ 3 + F.fetch_int(3)
    
    block = [int(v.integer_representation()).to_bytes(2, "little") for v in block]
    return b''.join(block)

def encrypt(b):
    v = 8 - len(b) % 8
    padded = b + bytes([v] * v) + sha256(b).digest()[:8]
    ct = b'\x00' * 8
    for idx in range(0, len(padded), 8):
        block = padded[idx : idx + 8]
        ct += encrypt_block(xor(block, ct[-8 :]))
    return ct

print("Test encryption")
for inp, out in pairs[:10]:
    assert encrypt_block(inp) == out
print("Test encryption passed")

Send your commands freely!

print("Finished!!!\n\n\n")
r.recvuntil(b'> ')
print("Enter commands like 'ls', 'cat flag'")
while True:
    inp = input("Command: ")
    r.sendline(encrypt(inp.strip().encode()).hex().encode())
    print(r.recvuntil(b'> ').decode())

You can find the file named __wow_hidden_flag__ inside the directory. The flag is codegate2022{I_am_the_Apex_Predator!SHAAAARK!!!}. :D

Conclusion

Interpolation attacks apply quite widely, since even $x^{-1}$ over a finite field can be rewritten as a power $x^k$, for example $x^{254}$ over $\text{GF}(2^8)$. Even if the attack is not that practical against strong block ciphers like AES, it gives nice insight into S-box structures.
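The $x^{-1} = x^{254}$ claim is easy to check by brute force. Here is a small sketch, assuming the modulus $x^8+x^4+x^3+x+1$ for $\text{GF}(2^8)$ (any irreducible degree-8 modulus would do); the helper names are mine.

```python
def gf_mul(a, b):
    """Multiply in GF(2^8) modulo x^8+x^4+x^3+x+1."""
    res = 0
    while b:
        if b & 1:
            res ^= a
        b >>= 1
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
    return res


def gf_pow(a, e):
    """Square-and-multiply exponentiation in GF(2^8)."""
    res = 1
    while e:
        if e & 1:
            res = gf_mul(res, a)
        a = gf_mul(a, a)
        e >>= 1
    return res


# a * a^254 = a^255 = 1 for every nonzero a, so a^254 is the inverse of a.
inverse_is_power = all(gf_mul(a, gf_pow(a, 254)) == 1 for a in range(1, 256))
print(inverse_is_power)
```

The catch is the degree: a polynomial of degree 254 needs far more known pairs per S-box layer than a cube does, which is part of why the attack gets impractical quickly.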

Please check out the original paper if you wanna learn more.