"""Cryptodome style class for naive/textbook RSA"""

import math
from Crypto.PublicKey import RSA
from Crypto.Util.number import getPrime, inverse


class TextbookRSACipher:
    """Replacement for ``PKCS.new(key)`` implementing textbook RSA.

    Naive alternative to ``cipher = PKCS.new(priv_key)``. No padding is
    applied: encryption is ``c = m**e mod n`` and decryption is
    ``m = c**d mod n``. Identical plaintexts therefore always give identical
    ciphertexts.

    The wrapped key is only accessed through its ``n`` and ``e`` attributes
    (encryption) and its ``d`` attribute (decryption).
    """

    @classmethod
    def construct(cls, k_size: int, e: int = 2**16 + 1):
        """Generate a new RSA key pair as raw integers.

        Compatibility wrapper for ``RSA.generate(k_size)``, except that it
        returns a plain tuple instead of an ``RsaKey``.

        :param k_size: modulus size in bits; each prime has ``k_size // 2`` bits.
        :param e: public exponent (default 65537).
        :return: tuple ``(n, e, d)``. ``n`` is the modulus, ``e`` the public
            exponent and ``d`` the private exponent.
        :raises ArithmeticError: if ``e`` is not coprime with ``phi(n)``.
            A proper implementation would draw new primes instead of raising.
        """
        half = k_size // 2
        p = getPrime(half)
        q = getPrime(half)
        # If p == q then n = p**2, and (p - 1) * (q - 1) is NOT phi(n).
        # The derived d would then be wrong, so redraw q until the primes differ.
        while q == p:
            q = getPrime(half)
        n = p * q
        phi = (p - 1) * (q - 1)
        if math.gcd(e, phi) != 1:
            raise ArithmeticError(f"{e=} is NOT coprime with {phi=}")
        d = inverse(e, phi)

        return (n, e, d)

    @classmethod
    def new(cls, key):
        """Create a cipher bound to *key*. Compatibility wrapper for ``PKCS.new(key)``.

        Instantiates ``cls`` so that subclasses get instances of themselves.
        """
        return cls(key)

    def __init__(self, key: "RSA.RsaKey"):
        # Only key.n, key.e (encrypt) and key.d (decrypt) are read later, so
        # any object exposing those attributes can be used.
        self._key = key

    def encrypt(self, msg: bytes) -> bytes:
        """Encrypt *msg* with textbook RSA: ``c = m**e mod n``.

        *msg* is read as a big-endian unsigned integer ``m``. It must satisfy
        ``m < n``. Larger values would be reduced modulo ``n`` and could not
        be recovered by ``decrypt``.

        :return: ciphertext as big-endian bytes, left-padded to the byte
            length of ``n``. The ciphertext size therefore does not depend on
            the message (as with PKCS#1 I2OSP).
        :raises ValueError: if ``m >= n``.
        """
        n, e = self._key.n, self._key.e
        m: int = int.from_bytes(msg, "big")
        if m >= n:
            raise ValueError(
                f"Cleartext {m.bit_length()} bits is too long (max={n.bit_length()})"
            )
        c = pow(m, e, n)
        modulus_len = (n.bit_length() + 7) // 8
        return c.to_bytes(modulus_len, "big")

    def decrypt(self, msg: bytes) -> bytes:
        """Decrypt *msg* with textbook RSA: ``m = c**d mod n``.

        *msg* is read as a big-endian unsigned integer ``c``. Leading zero
        bytes in *msg* are therefore irrelevant.

        :return: plaintext as the shortest big-endian byte string for ``m``.
            Leading zero bytes of the original message cannot be restored,
            and a zero plaintext decrypts to ``b""``.
        :raises ValueError: if ``c >= n``.
        """
        n, d = self._key.n, self._key.d
        c = int.from_bytes(msg, "big")
        if c >= n:
            raise ValueError(
                f"Ciphertext {c.bit_length()} bits is too long (max={n.bit_length()})"
            )
        m = pow(c, d, n)
        return m.to_bytes((m.bit_length() + 7) // 8, "big")


if __name__ == "__main__":
    # Example usage: round-trip a short message through a fresh 1024-bit key.
    private_key = RSA.generate(1024)
    rsa = TextbookRSACipher.new(private_key)
    plaintext = b"Attack at dawn!"

    recovered = rsa.decrypt(rsa.encrypt(plaintext))
    assert recovered == plaintext

    # Also show the raw (n, e, d) triple produced by the key-generation helper.
    print(TextbookRSACipher.construct(128))
