"""[MIF29] TP-Oracle : attaque à l'oracle de parité

Voir <http://mif29-crypto-sec.pages.univ-lyon1.fr/TP-oracle/>
"""

import binascii
import math
import json
from pathlib import Path

import requests
from Crypto.PublicKey import RSA

# TODO: change the URL used for the attack (use your own VM).
# Parity-oracle endpoint of the challenge server.
HOST = "challenge-64.mif29.os.univ-lyon1.fr"
PORT = 8080
URL = f"http://{HOST}:{PORT}/parity/"

# NOTE(review): URL_EXAMPLE is not referenced elsewhere in this file — confirm it is still needed.
URL_EXAMPLE = "http://mif29-crypto-sec.pages.univ-lyon1.fr/TP4-Oracle/example.b64"
# Key size in bits; presumably the bit size of the server's RSA modulus — not checked here.
KEY_SIZE = 1 << 10

# On-disk copies of the downloaded ciphertext and public key.
CIPHERTEXT_FILE, KEYFILE_PUB = Path("ciphertext.b64"), Path("rsa_parity_public.pem")

# Global HTTP session object; get_parity() sends every oracle query through it.
session = requests.Session()


## Préparation : téléchargement si nécessaire du message et de la clef publique


def download_content(filepath, uri):
    """Download *uri* into *filepath* unless that file already exists.

    The file acts as a cache: once it exists it is never downloaded again.
    An HTTP error body must therefore never be written to it, or every
    later run would reuse the error page.

    :param filepath: destination :class:`pathlib.Path`
    :param uri: URL fetched with an HTTP GET (3 s timeout)
    :raises requests.HTTPError: on a 4xx/5xx answer; nothing is written then
    """
    if filepath.exists():
        return
    req = requests.get(uri, timeout=3.0)
    # Fail before touching the file so a bad answer is not cached forever.
    req.raise_for_status()
    filepath.write_bytes(req.content)
    print(f"File {filepath} written (first download)")


# First run only: fetch the ciphertext and the public key from the server.
download_content(CIPHERTEXT_FILE, URL + "ciphered")
download_content(KEYFILE_PUB, URL + "pubkey")

## Load the ciphertext (kept as raw bytes) and the public key
original_ciphertext = CIPHERTEXT_FILE.read_bytes()
print(f"File {CIPHERTEXT_FILE} read")

public_key = RSA.import_key(KEYFILE_PUB.read_text(encoding="ascii"))
print(f"File {KEYFILE_PUB} read ({public_key.e=})")


## helpers
def get_parity(message: bytes) -> bool:
    """Ask the remote oracle for the parity of the decryption of *message*.

    :param message: base64-encoded ciphertext (bytes)
    :return: True iff the decrypted cleartext is odd
    :raises ValueError: on a non-200 HTTP status, or when the JSON reply has
        no ``"answer"`` field or its value is neither ``"Even"`` nor ``"Odd"``
    """
    # json= serialises the payload and also sets Content-Type: application/json.
    http_answer = session.post(
        URL, json={"base64": message.decode("utf-8")}, timeout=1.0
    )
    if http_answer.status_code != 200:
        raise ValueError(f"{http_answer.status_code} ({http_answer.reason}): {http_answer.content}")
    json_content = http_answer.json()
    print(json_content)
    # A malformed reply (not an object, missing key) must surface as the
    # documented ValueError, not as KeyError/TypeError.
    answer = json_content.get("answer") if isinstance(json_content, dict) else None
    if answer not in ("Even", "Odd"):
        raise ValueError(f"Unknown answer {http_answer.content=}")
    return answer == "Odd"


def encrypt(pubkey, m: bytes):
    """Textbook RSA encryption of a base64-encoded cleartext.

    :param pubkey: RSA public key exposing ``n`` (modulus) and ``e``
    :param m: base64-encoded cleartext (bytes or ASCII str)
    :return: base64-encoded ciphertext without a trailing newline, in the
        minimal big-endian byte length (empty for a zero ciphertext)
    :raises ValueError: if the cleartext, read as a big-endian integer, is
        not strictly smaller than ``n``
    """
    n, e = pubkey.n, pubkey.e
    msg = int.from_bytes(binascii.a2b_base64(m), "big", signed=False)
    # RSA works on residues mod n: msg == n would silently encrypt to 0.
    if msg >= n:
        raise ValueError(f"Cleartext {msg.bit_length()} bits is too long")
    ciphered = pow(msg, e, n)
    size = (ciphered.bit_length() + 7) // 8  # exact integer ceil(bits / 8)
    return binascii.b2a_base64(ciphered.to_bytes(size, "big"), newline=False)


def multiply_cleartext_by_2(pubkey, m: bytes):
    """Double the hidden cleartext of a base64 RSA ciphertext.

    Textbook RSA is multiplicatively homomorphic: if ``c = x**e mod n`` then
    ``c * 2**e mod n`` is the encryption of ``2 * x mod n``. This needs only
    the public key, not ``x``.

    :param pubkey: RSA public key exposing ``n`` and ``e``
    :param m: base64-encoded ciphertext (bytes, as produced by ``encrypt``)
    :return: base64-encoded ciphertext of the doubled cleartext, without a
        trailing newline, in minimal big-endian byte length (same encoding
        as ``encrypt``)
    """
    n, e = pubkey.n, pubkey.e
    ciphered = int.from_bytes(binascii.a2b_base64(m), "big", signed=False)
    doubled = (ciphered * pow(2, e, n)) % n
    size = (doubled.bit_length() + 7) // 8
    return binascii.b2a_base64(doubled.to_bytes(size, "big"), newline=False)


def find_message(pubkey, base_message: bytes, *, oracle=None):
    """Recover the cleartext of *base_message* using a parity oracle.

    The dichotomy runs on ``[0, n)``, not ``[0, 2**KEY_SIZE)``, because the
    halving must follow the modulus. Call the unknown cleartext ``x``.

    After ``k`` queries the invariant is
    ``a * n / 2**k <= x < (a + 1) * n / 2**k``. The next query asks for the
    parity of ``2**(k+1) * x mod n``. Since ``n`` is odd, that value is odd
    exactly when ``x`` lies in the upper half of the interval. So each query
    appends one bit to ``a``.

    :param pubkey: RSA public key exposing ``n`` (odd modulus) and ``e``
    :param base_message: base64-encoded ciphertext to decrypt
    :param oracle: callable taking a base64 ciphertext and returning True iff
        its decryption is odd; defaults to the remote ``get_parity``
    :return: the cleartext as a non-negative integer, after
        ``n.bit_length()`` oracle queries; use ``int.to_bytes`` for bytes
    """
    if oracle is None:
        oracle = get_parity
    n = pubkey.n
    nbits = n.bit_length()
    ciphertext = base_message
    a = 0
    for _ in range(nbits):
        ciphertext = multiply_cleartext_by_2(pubkey, ciphertext)
        a = (a << 1) | int(bool(oracle(ciphertext)))
    # n < 2**nbits, so the final interval is narrower than 1 and holds exactly
    # one integer: the ceiling of its lower bound a * n / 2**nbits.
    return (a * n + (1 << nbits) - 1) >> nbits


if __name__ == "__main__":
    # Demo: encrypt a known cleartext locally, then ask the remote oracle
    # for the parity of its decryption.
    demo_clear = "attack at dawn!".encode("utf-8")
    print(f"Cleartext={demo_clear}")
    demo_b64 = binascii.b2a_base64(demo_clear, newline=False)
    print(f"B64 cleartext={demo_b64}")
    demo_cipher = encrypt(public_key, demo_b64)
    print(f"B64 ciphered={demo_cipher}")
    print(f"{get_parity(demo_cipher)}")