import hashlib, json, time, socket, threading, sys, os, sqlite3
from ecdsa import VerifyingKey, SigningKey, SECP256k1, BadSignatureError
from base64 import b64encode, b64decode
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.backends import default_backend
from cryptography.fernet import Fernet
import base64

# --- import du module sync (serveur distant) ---
from sync import (
    remote_get_user,
    remote_register_user,
    remote_status,
    fetch_remote_pool,
    push_mined_block,
    push_remote_transaction
)

# Default SQLite database file, shared by UserStore, ChainStore and Blockchain.
DB_FILE = "blockchain.db"

# ========================== utilitaires ==========================
def hash_password(password: str) -> str:
    """Return the hex SHA-256 digest of *password* (UTF-8 encoded).

    NOTE(review): unsalted and fast, so weak for password storage; the same
    format is stored in the local DB and sent to the remote server, so it
    cannot be changed without a migration.
    """
    digest = hashlib.sha256()
    digest.update(password.encode())
    return digest.hexdigest()

def derive_key_from_password(password: str, salt: bytes) -> bytes:
    """Derive a key usable by ``Fernet`` from *password* and *salt*.

    PBKDF2-HMAC-SHA256, 200 000 iterations, 32-byte output, returned
    URL-safe base64 encoded.
    """
    raw_key = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=200000,
        backend=default_backend(),
    ).derive(password.encode())
    return base64.urlsafe_b64encode(raw_key)

def encrypt_private_key(private_pem: str, password: str) -> str:
    """Encrypt *private_pem* with a key derived from *password*.

    A fresh 16-byte random salt is generated per call. The returned string is
    base64 of ``salt + fernet_token``.
    """
    salt = os.urandom(16)
    cipher = Fernet(derive_key_from_password(password, salt))
    payload = salt + cipher.encrypt(private_pem.encode())
    return b64encode(payload).decode()

def decrypt_private_key(enc_data_b64: str, password: str) -> str:
    """Reverse ``encrypt_private_key``: return the PEM stored in *enc_data_b64*.

    The decoded blob is split into its 16-byte salt and the Fernet token. The
    key is re-derived from *password*. ``Fernet.decrypt`` raises if the token
    cannot be authenticated (e.g. wrong password).
    """
    blob = b64decode(enc_data_b64.encode())
    salt = blob[:16]
    token = blob[16:]
    cipher = Fernet(derive_key_from_password(password, salt))
    return cipher.decrypt(token).decode()

def send_message(host, port, msg_obj, expect_reply=False, timeout=3):
    """Send *msg_obj* as JSON over a new TCP connection to ``(host, port)``.

    With ``expect_reply=True`` the sending half of the socket is shut down after
    the payload. Everything the peer sends until it closes is then decoded as
    JSON and returned.

    Returns ``None`` when no reply is expected, when the reply is empty, or on
    any error (connection failure, timeout, invalid JSON). Errors raised by
    ``json.dumps`` on *msg_obj* are not caught.
    """
    payload = json.dumps(msg_obj).encode()
    try:
        with socket.create_connection((host, port), timeout=timeout) as sock:
            sock.sendall(payload)
            if not expect_reply:
                return None
            # Half-close so the peer sees end-of-stream for the request.
            sock.shutdown(socket.SHUT_WR)
            received = list(iter(lambda: sock.recv(4096), b""))
            if not received:
                return None
            return json.loads(b"".join(received).decode())
    except Exception:
        return None

def input_hidden(prompt="Mot de passe : "):
    """Read a password, echoing '*' instead of the typed characters.

    On a POSIX tty the terminal is switched to raw mode and read one character
    at a time:

    - Enter, Ctrl-D or end of input finishes.
    - Backspace/DEL erases the last character.
    - Ctrl-C raises ``KeyboardInterrupt``.

    The original terminal settings are always restored. If ``termios``/``tty``
    are unavailable, stdin is not a tty, or an error occurs while reading, falls
    back to a visible ``input()``.
    """
    try:
        import termios, tty
        fd = sys.stdin.fileno()
        if not sys.stdin.isatty():
            raise OSError("stdin is not a tty")
        old = termios.tcgetattr(fd)
    except Exception:
        return input(prompt + "(visible) ")

    sys.stdout.write(prompt)
    sys.stdout.flush()
    chars = []
    try:
        try:
            tty.setraw(fd)
            while True:
                ch = sys.stdin.read(1)
                # "" means end of stream (would otherwise loop forever); "\x04" is Ctrl-D.
                if ch in ("", "\r", "\n", "\x04"):
                    break
                if ch == "\x03":
                    # Raw mode disables SIGINT generation, so honour Ctrl-C explicitly.
                    raise KeyboardInterrupt
                if ch in ("\x7f", "\x08"):  # DEL / Backspace
                    if chars:
                        chars.pop()
                        sys.stdout.write("\b \b")
                        sys.stdout.flush()
                else:
                    chars.append(ch)
                    sys.stdout.write("*")
                    sys.stdout.flush()
        finally:
            # Restore the terminal before anything else is printed or read.
            termios.tcsetattr(fd, termios.TCSADRAIN, old)
            sys.stdout.write("\n")
            sys.stdout.flush()
    except Exception:
        # Terminal is already restored here, so the visible fallback behaves normally.
        return input(prompt + "(visible) ")
    return "".join(chars)


# ============================ stockage utilisateurs ============================
class UserStore:
    """Local SQLite store of user accounts (``users`` table), synced with the remote server.

    Each row holds:

    - the username;
    - the ``hash_password`` hash;
    - the private key PEM encrypted with the user's password;
    - the public key PEM;
    - the address.
    """

    def __init__(self, db_path=DB_FILE):
        # check_same_thread=False allows the connection to be used from threads other than its creator.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.cur = self.conn.cursor()
        self._create_table()

    def _create_table(self):
        """Create the ``users`` table if it does not exist."""
        self.cur.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE,
                password_hash TEXT,
                private_key_enc TEXT,
                public_key TEXT,
                address TEXT UNIQUE
            )
        """)
        self.conn.commit()

    def get_user(self, username):
        """Return the stored row for *username* as a dict, or ``None`` if unknown locally."""
        self.cur.execute("SELECT username, password_hash, private_key_enc, public_key, address FROM users WHERE username=?", (username,))
        row = self.cur.fetchone()
        if not row:
            return None
        uname, pwd_hash, priv_enc, pub, addr = row
        return {"username": uname, "password_hash": pwd_hash, "private_key_enc": priv_enc, "public_key": pub, "address": addr}

    def get_public_by_address(self, address):
        """Return the public key PEM stored locally for *address*, or ``None``."""
        self.cur.execute("SELECT public_key FROM users WHERE address=?", (address,))
        row = self.cur.fetchone()
        return row[0] if row else None

    def create_user(self, username, password):
        """Create *username* locally, or import it from the remote server.

        - If the server is reachable and knows *username*, its keys and address
          are copied locally. The private key is re-encrypted with *password*.
        - Otherwise a fresh SECP256k1 key pair is generated and stored locally.
          When online, the new user is also registered on the server.

        Returns:
            ``(user_dict, None)`` on success, ``(None, message)`` on failure.
        """
        if self.get_user(username):
            return None, "Utilisateur déjà existant."

        print("Connexion au serveur distant...")
        online = remote_status()
        if online:
            remote_user = remote_get_user(username=username)
            if remote_user:
                print(f"Utilisateur '{username}' trouvé sur le serveur distant.")
                # NOTE(review): *password* is never checked against the remote account
                # before its private key is imported and re-encrypted with it, so anyone
                # knowing a remote username obtains that key locally. Also, a missing
                # "privpem" stores an encrypted empty key that cannot sign later.
                # Confirm the intended verification with the server API.
                pwd_hash = hash_password(password)
                enc_priv = encrypt_private_key(remote_user.get("privpem", ""), password)
                self.cur.execute("""
                    INSERT OR IGNORE INTO users(username, password_hash, private_key_enc, public_key, address)
                    VALUES (?, ?, ?, ?, ?)
                """, (
                    username,
                    pwd_hash,
                    enc_priv,
                    remote_user["pubkey"],
                    remote_user["address"]
                ))
                # INSERT OR IGNORE skips the row on a UNIQUE conflict (address already
                # used by another local account): report it instead of claiming success.
                inserted = self.cur.rowcount
                self.conn.commit()
                if inserted == 0:
                    return None, "Import impossible : adresse déjà utilisée par un compte local."
                print(f"Utilisateur '{username}' importé depuis le serveur distant.")
                return {
                    "username": username,
                    "address": remote_user["address"],
                    "public_key": remote_user["pubkey"],
                    "private_key_enc": enc_priv
                }, None
        else:
            print("Serveur distant injoignable. Création locale uniquement.")

        # --- local creation ---
        sk = SigningKey.generate(curve=SECP256k1)
        vk = sk.verifying_key
        priv_pem = sk.to_pem().decode()
        pub_pem = vk.to_pem().decode()
        # Address = first 32 hex chars of SHA-256(public key PEM).
        address = hashlib.sha256(pub_pem.encode()).hexdigest()[:32]
        enc_priv = encrypt_private_key(priv_pem, password)
        pwd_hash = hash_password(password)

        self.cur.execute("""
            INSERT INTO users(username, password_hash, private_key_enc, public_key, address)
            VALUES (?, ?, ?, ?, ?)
        """, (username, pwd_hash, enc_priv, pub_pem, address))
        self.conn.commit()

        # --- registration on the remote server ---
        if online:
            # NOTE(review): the private key PEM is sent to the server unencrypted here.
            resp = remote_register_user(username, pwd_hash, pub_pem, priv_pem, address)
            # The sync helper may return a falsy value on failure; don't crash on it.
            if resp and resp.get("ok"):
                print(f"Utilisateur '{username}' également enregistré sur le serveur distant.")
            else:
                print("⚠️ Échec de l’enregistrement distant (erreur API).")
        else:
            print("⚠️ Serveur distant non disponible, utilisateur uniquement local.")

        return {
            "username": username,
            "address": address,
            "public_key": pub_pem,
            "private_key_enc": enc_priv
        }, None

    def authenticate(self, username, password):
        """Check *password* and return the user dict with its decrypted ``private_key``.

        Returns:
            ``(user_dict, None)`` on success, ``(None, message)`` if the user is
            unknown, the password hash differs, or the private key cannot be decrypted.
        """
        user = self.get_user(username)
        if not user:
            return None, "Utilisateur inconnu."
        if user["password_hash"] != hash_password(password):
            return None, "Mot de passe incorrect."
        try:
            priv_pem = decrypt_private_key(user["private_key_enc"], password)
        except Exception:
            return None, "Impossible de déchiffrer la clé privée (mot de passe erroné ?)"
        user["private_key"] = priv_pem
        return user, None


# ============================ blockchain ============================
class Block:
    """A single proof-of-work block.

    When *hash_value* is falsy the constructor mines immediately. It increments
    ``nonce`` until the SHA-256 of the header starts with ``difficulty`` zero
    hex digits. Otherwise the supplied hash is stored as-is, without
    re-verification.
    """

    def __init__(self, index, previous_hash, timestamp, transactions, difficulty, nonce=0, hash_value=None):
        self.index = index
        self.previous_hash = previous_hash
        self.timestamp = timestamp
        self.transactions = transactions
        self.difficulty = difficulty
        self.nonce = nonce
        # A supplied hash is trusted; otherwise the proof-of-work is done right now.
        self.hash = hash_value if hash_value else self.mine()

    def header_dict(self):
        """Return the fields covered by the block hash as a plain dict."""
        return dict(
            index=self.index,
            previous_hash=self.previous_hash,
            timestamp=self.timestamp,
            transactions=self.transactions,
            difficulty=self.difficulty,
            nonce=self.nonce,
        )

    def calculate_hash(self):
        """Return the hex SHA-256 of the header as canonical (sorted-key, compact) JSON."""
        canonical = json.dumps(self.header_dict(), separators=(",", ":"), sort_keys=True)
        return hashlib.sha256(canonical.encode()).hexdigest()

    def mine(self):
        """Advance ``nonce`` until the hash meets the difficulty target; return that hash."""
        prefix = "0" * self.difficulty
        digest = self.calculate_hash()
        while not digest.startswith(prefix):
            self.nonce += 1
            digest = self.calculate_hash()
        return digest


class ChainStore:
    """SQLite persistence for blocks (``blocks``) and their transactions (``transactions``)."""

    def __init__(self, db_path=DB_FILE):
        # check_same_thread=False: the store is used from peer-listener threads too.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.conn.execute("PRAGMA journal_mode=WAL;")  # switch the DB to write-ahead-log journal mode
        # Kept for callers that use it directly; methods below use per-statement
        # cursors so concurrent threads cannot clobber each other's results.
        self.cur = self.conn.cursor()
        self._create_tables()

    def _create_tables(self):
        """Create the ``blocks`` and ``transactions`` tables if they do not exist."""
        self.cur.execute("""
            CREATE TABLE IF NOT EXISTS blocks(
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                index_nb INTEGER UNIQUE,
                previous_hash TEXT,
                ts REAL,
                difficulty INTEGER,
                nonce INTEGER,
                hash TEXT
            )
        """)
        self.cur.execute("""
            CREATE TABLE IF NOT EXISTS transactions(
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                block_id INTEGER,
                sender_pub TEXT,
                recipient_address TEXT,
                amount REAL,
                signature_hex TEXT,
                FOREIGN KEY(block_id) REFERENCES blocks(id)
            )
        """)
        self.conn.commit()

    def add_block(self, block: Block):
        """Persist *block* and all its transactions in a single transaction.

        If any insert fails, nothing is written and the exception propagates.
        Possible causes: a duplicate ``index_nb``, or a transaction dict missing
        one of ``sender_pub``, ``recipient_address``, ``amount`` or
        ``signature_hex``.
        """
        with self.conn:  # commits on success, rolls back on any exception
            cur = self.conn.execute(
                "INSERT INTO blocks(index_nb,previous_hash,ts,difficulty,nonce,hash) VALUES(?,?,?,?,?,?)",
                (block.index, block.previous_hash, block.timestamp, block.difficulty, block.nonce, block.hash),
            )
            block_id = cur.lastrowid  # from this statement's own cursor, not the shared one
            rows = [
                (block_id, tx["sender_pub"], tx["recipient_address"], tx["amount"], tx["signature_hex"])
                for tx in block.transactions
            ]
            self.conn.executemany(
                "INSERT INTO transactions(block_id,sender_pub,recipient_address,amount,signature_hex) VALUES(?,?,?,?,?)",
                rows,
            )

    def balance_confirmed(self, address):
        """Return received minus sent amounts for *address* over all stored transactions.

        "Sent" means transactions whose ``sender_pub`` is the public key stored in
        the local ``users`` table for *address*.
        NOTE(review): for an address with no local ``users`` row, nothing is
        deducted. The ``users`` table must exist in the same database.
        """
        row = self.conn.execute("""
            SELECT 
              IFNULL(SUM(CASE WHEN recipient_address=? THEN amount ELSE 0 END),0) -
              IFNULL(SUM(CASE WHEN sender_pub IN (SELECT public_key FROM users WHERE address=?) THEN amount ELSE 0 END),0)
            FROM transactions
        """, (address, address)).fetchone()
        return float(row[0] or 0.0)


class Blockchain:
    """Proof-of-work chain persisted through ``ChainStore`` plus an in-memory pending pool.

    ``pending`` is appended to from PeerNode's listener threads as well as from
    the main thread.
    """

    def __init__(self, difficulty=4):
        # Only blocks created by this instance (the genesis block) are appended here;
        # existing blocks are read from the database and not loaded into this list.
        self.chain = []
        self.pending = []  # transactions waiting to be mined
        self.difficulty = difficulty  # required number of leading "0" hex digits in a block hash
        self.store = ChainStore(DB_FILE)
        rows = self.store.conn.execute("SELECT COUNT(*) FROM blocks").fetchone()[0]
        if rows == 0:
            self.create_genesis()

    def create_genesis(self):
        """Mine and store block #0 containing a single zero-amount SYSTEM transaction."""
        gtx = {"sender_pub": "SYSTEM", "recipient_address": "GENESIS", "amount": 0.0, "signature_hex": ""}
        genesis = Block(0, "0", time.time(), [gtx], self.difficulty)
        self.chain.append(genesis)
        self.store.add_block(genesis)

    @staticmethod
    def _tx_msg(sender_pub, recipient_address, amount):
        """Build the message to sign for a transaction: ``"sender|recipient|amount"`` as bytes.

        NOTE(review): PeerNode.add_transaction signs a different format
        (address|recipient|amount:.8f|created_at); confirm which one is canonical.
        """
        return f"{sender_pub}|{recipient_address}|{amount}".encode()

    def add_transaction(self, tx):
        """Append *tx* to the pending pool (no validation is performed)."""
        self.pending.append(tx)

    def mine_pending(self, miner_address):
        """Mine the currently pending transactions plus a 5.0 reward into a new stored block.

        Returns the new ``Block``, or ``None`` if nothing is pending.

        - Transactions appended while mining stay pending for the next block.
        - If storing the block fails, the mined transactions also stay pending
          and the exception propagates.
        """
        if not self.pending:
            return None
        # Snapshot: peer threads may append to self.pending during proof-of-work.
        batch = list(self.pending)
        reward = {"sender_pub": "SYSTEM", "recipient_address": miner_address, "amount": 5.0, "signature_hex": ""}
        prev_hash = self.store.conn.execute("SELECT hash FROM blocks ORDER BY index_nb DESC LIMIT 1").fetchone()[0]
        index = self.store.conn.execute("SELECT IFNULL(MAX(index_nb),0)+1 FROM blocks").fetchone()[0]
        blk = Block(index, prev_hash, time.time(), batch + [reward], self.difficulty)
        self.store.add_block(blk)
        # Remove only what was mined; new transactions are appended after it.
        del self.pending[:len(batch)]
        return blk


# ============================ PeerNode ============================
class PeerNode:
    """P2P node for one logged-in user.

    A background daemon thread listens on ``(host, port)`` for JSON messages
    from peers. The menu actions (send, view pending, mine, balance) are
    exposed on top of a local ``Blockchain``.
    """

    def __init__(self, host, port, peers, user):
        self.host = host
        self.port = port
        self.peers = set(peers)  # (host, port) tuples
        self.user = user  # dict from UserStore.authenticate, including the decrypted "private_key"
        self.blockchain = Blockchain()
        self._server_thread = threading.Thread(target=self._listen_loop, daemon=True)
        self._server_thread.start()
        print(f"[node] écoute sur {self.host}:{self.port} ...")

    def _listen_loop(self):
        """Accept connections forever, handling each in its own daemon thread."""
        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srv.bind((self.host, self.port))
        srv.listen()
        while True:
            conn, _ = srv.accept()
            threading.Thread(target=self._handle_conn, args=(conn,), daemon=True).start()

    def _handle_conn(self, conn):
        """Read one JSON message from *conn* until the peer closes, then apply it.

        Message types:

        - ``NEW_TX``: the payload is appended to the pending pool.
        - ``NEW_BLOCK``: the payload is passed as keyword arguments to
          ``Block``. It is stored only if its hash matches its header and
          difficulty.

        Malformed messages are reported and ignored.
        """
        try:
            # A single recv() may return only part of the message; read until EOF.
            chunks = []
            while True:
                buf = conn.recv(4096)
                if not buf:
                    break
                chunks.append(buf)
            data = b"".join(chunks)
            if not data:
                return
            msg = json.loads(data.decode())
            kind = msg.get("type")
            if kind == "NEW_TX":
                self.blockchain.add_transaction(msg["payload"])
            elif kind == "NEW_BLOCK":
                # NOTE(review): a payload without "hash_value" makes Block mine it here,
                # at the sender-chosen difficulty.
                b = Block(**msg["payload"])
                if b.hash != b.calculate_hash() or not b.hash.startswith("0" * b.difficulty):
                    print(f"Bloc #{b.index} rejeté : hash invalide.")
                    return
                self.blockchain.store.add_block(b)
                print(f"Bloc reçu #{b.index}")
        except (ValueError, KeyError, TypeError, AttributeError, OSError, sqlite3.Error) as e:
            print(f"[node] message ignoré : {e}")
        finally:
            conn.close()

    def add_transaction(self, recipient_address, amount):
        """Create and sign a transaction, then distribute it.

        The transaction is added to the local pending pool, broadcast to peers
        and pushed to the remote server.

        The signed message is ``"<sender_addr>|<recipient>|<amount:.8f>|<created_at>"``.
        The signature is stored base64-encoded under ``signature_hex``.

        Prints a message and does nothing if the amount is not a positive finite
        number or the recipient address is not known locally.
        """
        if not (0 < amount < float("inf")):  # also rejects NaN
            print("Montant invalide.")
            return
        users = UserStore()
        try:
            recipient_pub = users.get_public_by_address(recipient_address)
        finally:
            users.conn.close()  # a new connection is opened per call; don't leak it
        if not recipient_pub:
            print("Adresse inconnue.")
            return

        sk = SigningKey.from_pem(self.user["private_key"].encode())
        pub_sender = self.user["public_key"]

        created_at = time.strftime("%Y-%m-%d %H:%M:%S")
        amount_fmt = f"{amount:.8f}"  # 8 decimals, like the server
        # NOTE(review): UserStore.create_user derives addresses from the first 32 hex
        # chars of sha256(pub_pem); 40 are used here. Presumably this matches the
        # remote server's signature check. Confirm.
        sender_addr = hashlib.sha256(pub_sender.encode()).hexdigest()[:40]
        msg = f"{sender_addr}|{recipient_address}|{amount_fmt}|{created_at}".encode()
        sig_b64 = base64.b64encode(sk.sign(msg, hashfunc=hashlib.sha256)).decode()

        tx = {
            "sender_pub": pub_sender,
            "recipient_address": recipient_address,
            "amount": float(amount),
            "signature_hex": sig_b64,  # name kept for compatibility; content is base64
            "created_at": created_at
        }

        self.blockchain.add_transaction(tx)
        print("Transaction ajoutée localement et diffusée.")

        # --- P2P broadcast (best effort: send_message swallows failures) ---
        for h, p in list(self.peers):
            send_message(h, p, {"type": "NEW_TX", "payload": tx})

        # --- push to the remote server ---
        push_remote_transaction(tx)

    def view_pending(self):
        """Print the local pending transactions (truncated sender/recipient)."""
        if not self.blockchain.pending:
            print("Aucune transaction locale en attente.")
            return
        print("\n=== Transactions à miner ===")
        for i, tx in enumerate(self.blockchain.pending, 1):
            print(f"{i}) {tx['sender_pub'][:12]}... -> {tx['recipient_address'][:12]}... : {tx['amount']} B")
        print("============================")

    def mine(self):
        """Mine the pending transactions into a block, crediting this user with the reward."""
        blk = self.blockchain.mine_pending(self.user["address"])
        if blk:
            print(f"Bloc #{blk.index} miné avec succès ! Récompense : 5.00 B")
        else:
            print("Aucune transaction à miner.")

    def balance(self):
        """Print this user's confirmed balance as computed by ``ChainStore.balance_confirmed``."""
        bal = self.blockchain.store.balance_confirmed(self.user["address"])
        print(f"Solde confirmé de {self.user['username']} ({self.user['address']}) : {bal:.4f} B")


# ============================ Interface ============================
def _parse_peers(text):
    """Parse ``"host:port,host:port"`` into ``[(host, port), ...]``, skipping malformed entries."""
    peers = []
    for item in text.split(","):
        item = item.strip()
        # rpartition keeps hosts that themselves contain ':' (e.g. IPv6) intact.
        host, sep, port = item.rpartition(":")
        if not sep:
            continue  # entries without ':' are ignored
        try:
            peers.append((host.strip(), int(port)))
        except ValueError:
            print(f"Pair ignoré (port invalide) : {item}")
    return peers


def _show_remote_pool():
    """Print the remote server's pool of transactions waiting to be mined."""
    remote_txs = fetch_remote_pool()
    if not remote_txs:
        print("Aucune transaction distante à miner.")
        return
    print("\n=== Transactions à miner sur le serveur ===")
    for i, tx in enumerate(remote_txs, 1):
        sender = str(tx.get("sender", "?"))[:12]
        receiver = str(tx.get("receiver", "?"))[:12]
        print(f"{i}) {sender}... -> {receiver}... : {tx.get('amount', '?')} B")
    print("============================")


def _sync_and_mine(node):
    """Import the remote pool into *node*'s pending transactions, mine a block, publish it."""
    print("\n--- Synchronisation avec le serveur distant ---")
    remote_txs = fetch_remote_pool()
    if remote_txs:
        print(f"{len(remote_txs)} transaction(s) distante(s) importée(s) pour minage.")
        for tx in remote_txs:
            try:
                node.blockchain.add_transaction({
                    "sender_pub": tx.get("pubkey", ""),
                    "recipient_address": tx.get("receiver", ""),
                    "amount": float(tx.get("amount", 0)),
                    "signature_hex": tx.get("signature", "")
                })
            except (ValueError, TypeError, AttributeError) as e:
                print("Erreur transaction distante :", e)
    else:
        print("Aucune transaction distante trouvée.")

    # --- local mining ---
    blk = node.blockchain.mine_pending(node.user["address"])
    if not blk:
        print("Aucune transaction à miner.")
        return
    print(f"Bloc #{blk.index} miné avec succès ! Récompense : 5.00 B")
    # --- publish to the remote server ---
    res = push_mined_block({
        "transactions": blk.transactions,
        "miner": node.user["address"],
        "previous_hash": blk.previous_hash,
        "hash": blk.hash,
        "nonce": blk.nonce,
        "difficulty": blk.difficulty
    })
    # The sync helper may return a falsy value on failure; don't crash on it.
    if res and res.get("ok"):
        print("Bloc publié sur le serveur distant")
    else:
        print("Échec de publication sur le serveur distant.")


def main():
    """Interactive entry point: log in (or create) a user, start a node, run the menu loop."""
    print("=== Connexion à la blockchain ===")
    users = UserStore()
    username = input("Nom d'utilisateur : ").strip()
    password = input_hidden("Mot de passe : ")
    if not users.get_user(username):
        if input("Utilisateur inconnu. Créer ? (o/n) ").strip().lower() != "o":
            return
        _, err = users.create_user(username, password)
        if err:
            print(err)
            return
    user, err = users.authenticate(username, password)
    if err:
        print(err)
        return
    print(f"Bienvenue {user['username']} — adresse : {user['address']}")

    # Re-prompt instead of crashing on a non-numeric or out-of-range port.
    port = None
    while port is None:
        try:
            candidate = int(input("Port du nœud (ex: 4000) : "))
        except ValueError:
            candidate = -1
        if 0 <= candidate <= 65535:
            port = candidate
        else:
            print("Port invalide.")

    peers = _parse_peers(input("Pairs (ex: 127.0.0.1:4001,127.0.0.1:4002) ou vide : ").strip())

    node = PeerNode("0.0.0.0", port, peers, user)

    while True:
        print("\nMENU")
        print("a) Afficher mon solde")
        print("e) Envoyer une transaction")
        print("m) Miner un bloc")
        print("v) Voir les transactions à miner")
        print("q) Quitter")
        choix = input("Choix : ").strip()
        if choix == "e":
            addr = input("Adresse du destinataire : ").strip()
            try:
                amt = float(input("Montant : "))
            except ValueError:
                print("Montant invalide.")
                continue
            node.add_transaction(addr, amt)
        elif choix == "v":
            node.view_pending()   # local pool
            _show_remote_pool()   # remote pool
        elif choix == "m":
            _sync_and_mine(node)
        elif choix == "a":
            node.balance()
        elif choix == "q":
            print("Fin du programme.")
            break
        else:
            print("Choix invalide.")

# Run the interactive client only when executed as a script, not when imported.
if __name__ == "__main__":
    main()