#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import html
import json
import os
import sqlite3
from datetime import datetime

import requests
from scapy.all import rdpcap, IP, IPv6, TCP, UDP, Raw

# SQLite database holding one row per parsed packet (table `packets`).
DB_PATH = r"C:\Temp\projectStarlink\DB\wireshark.db"
# Static HTML report written by generate_html().
HTML_PATH = r"C:\Temp\projectStarlink\HTML\wireshark.html"

# Remote PCAPNG source; main() downloads the whole file on every run.
PCAP_URL = "https://dealejandro.belliniseven.ai/PCAPNG/loopbackCapture.pcapng"

# Local copy of the downloaded PCAPNG; rewritten by each download and read by
# parse_pcap_to_db().
PCAP_LOCAL = r"C:\Temp\projectStarlink\loopbackCapture_remote.pcapng"


def download_pcap():
    """Download the capture at PCAP_URL and store it at PCAP_LOCAL.

    The response body is streamed in chunks to a temporary ``.part`` file
    next to PCAP_LOCAL. That file is then moved into place with os.replace().
    As a result the capture is never held in memory in full, and a failed or
    interrupted download never leaves a truncated file at PCAP_LOCAL: any
    previous copy stays intact.

    Raises:
        requests.RequestException: on connection errors, timeouts or a
            non-2xx HTTP status (via raise_for_status()).
        OSError: if the file cannot be written or moved into place.
    """
    os.makedirs(os.path.dirname(PCAP_LOCAL), exist_ok=True)
    print("Downloading PCAPNG from remote source...")

    tmp_path = PCAP_LOCAL + ".part"
    try:
        # stream=True avoids buffering the whole capture; the context manager
        # releases the connection even if writing fails.
        with requests.get(PCAP_URL, timeout=30, stream=True) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)
        # Swap the completed download in as one step.
        os.replace(tmp_path, PCAP_LOCAL)
    except BaseException:
        # Don't leave a partial download behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print(f"Downloaded remote PCAPNG to: {PCAP_LOCAL}")


def init_db(db_path=None):
    """Open (creating if needed) the packet database and ensure its schema exists.

    Args:
        db_path: SQLite database to open. Defaults to DB_PATH. Its parent
            directory is created if missing. A bare filename or ":memory:"
            has no directory component and is opened as-is.

    Returns:
        An open sqlite3.Connection with the `packets` table present. The
        caller is responsible for closing it.

    Raises:
        sqlite3.Error: if the database cannot be opened or the schema cannot
            be created. The connection is closed before the error propagates.
    """
    if db_path is None:
        db_path = DB_PATH

    parent = os.path.dirname(db_path)
    if parent:  # os.makedirs("") would raise for bare filenames / ":memory:"
        os.makedirs(parent, exist_ok=True)

    schema_sql = """
        CREATE TABLE IF NOT EXISTS packets (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            protocol TEXT,
            src_ip TEXT,
            dst_ip TEXT,
            src_port INTEGER,
            dst_port INTEGER,
            payload_hex TEXT,
            payload_ascii TEXT
        )
    """

    conn = sqlite3.connect(db_path)
    try:
        conn.cursor().execute(schema_sql)
        conn.commit()
    except Exception:
        # Only close on failure: on success the open connection is returned.
        conn.close()
        raise
    return conn


def insert_packet(conn, entry):
    """Write a single packet record to the `packets` table and commit it.

    Args:
        conn: Open sqlite3 connection whose database has the `packets` table.
        entry: Mapping with the keys timestamp, protocol, src_ip, dst_ip,
            src_port, dst_port, payload_hex and payload_ascii. A missing key
            raises KeyError before anything is executed.
    """
    field_order = (
        "timestamp",
        "protocol",
        "src_ip",
        "dst_ip",
        "src_port",
        "dst_port",
        "payload_hex",
        "payload_ascii",
    )
    params = tuple(entry[field] for field in field_order)

    statement = """
        INSERT INTO packets (timestamp, protocol, src_ip, dst_ip, src_port, dst_port, payload_hex, payload_ascii)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    """
    conn.cursor().execute(statement, params)
    conn.commit()


def _packet_to_entry(pkt):
    """Return the `packets` column values for one scapy packet as a dict.

    Fields that do not apply to the packet (no IP layer, neither TCP nor UDP,
    no Raw payload, unusable timestamp) are None.
    """
    entry = {
        "timestamp": None,
        "protocol": None,
        "src_ip": None,
        "dst_ip": None,
        "src_port": None,
        "dst_port": None,
        "payload_hex": None,
        "payload_ascii": None,
    }

    if hasattr(pkt, "time"):
        try:
            entry["timestamp"] = datetime.fromtimestamp(float(pkt.time)).isoformat()
        except (TypeError, ValueError, OverflowError, OSError):
            # Unusable capture timestamp: keep the packet, just without a time.
            entry["timestamp"] = None

    # Accept IPv6 as well as IPv4 so ::1 loopback traffic also gets addresses.
    if IP in pkt:
        ip_layer = pkt[IP]
    elif IPv6 in pkt:
        ip_layer = pkt[IPv6]
    else:
        ip_layer = None
    if ip_layer is not None:
        entry["src_ip"] = ip_layer.src
        entry["dst_ip"] = ip_layer.dst

    if TCP in pkt:
        entry["protocol"] = "TCP"
        entry["src_port"] = pkt[TCP].sport
        entry["dst_port"] = pkt[TCP].dport
    elif UDP in pkt:
        entry["protocol"] = "UDP"
        entry["src_port"] = pkt[UDP].sport
        entry["dst_port"] = pkt[UDP].dport

    if Raw in pkt:
        raw_bytes = bytes(pkt[Raw].load)
        entry["payload_hex"] = raw_bytes.hex()
        entry["payload_ascii"] = raw_bytes.decode("utf-8", errors="replace")

    return entry


def parse_pcap_to_db():
    """Replace the contents of the `packets` table with the packets in PCAP_LOCAL.

    The capture at PCAP_LOCAL is a complete snapshot of the remote file, so
    the table is cleared before the snapshot is inserted. Appending would
    store every packet again on each run and inflate all counts in the report.

    The delete and all inserts run in one transaction. This means one commit
    instead of one per packet, a rollback on any error, and readers of the
    database see either the previous snapshot or the new one, never a mix.
    """
    columns = (
        "timestamp",
        "protocol",
        "src_ip",
        "dst_ip",
        "src_port",
        "dst_port",
        "payload_hex",
        "payload_ascii",
    )
    insert_sql = "INSERT INTO packets ({}) VALUES ({})".format(
        ", ".join(columns), ", ".join("?" * len(columns))
    )

    packets = rdpcap(PCAP_LOCAL)
    conn = init_db()
    try:
        with conn:  # commits on success, rolls back on exception
            conn.execute("DELETE FROM packets")
            conn.executemany(
                insert_sql,
                (
                    tuple(entry[col] for col in columns)
                    for entry in map(_packet_to_entry, packets)
                ),
            )
    finally:
        conn.close()


def get_stats(db_path=None):
    """Return summary statistics over the `packets` table.

    Args:
        db_path: SQLite database to read. Defaults to DB_PATH.

    Returns:
        A tuple ``(total_packets, proto_counts, top_src, top_dst)``:

        - ``total_packets`` is the number of rows.
        - ``proto_counts`` is a list of ``(protocol, count)`` pairs, one per
          distinct protocol value, with NULL included.
        - ``top_src`` / ``top_dst`` hold up to five ``(ip, count)`` pairs for
          the most frequent non-NULL source / destination IPs, most frequent
          first.

    Raises:
        sqlite3.Error: if a query fails (e.g. there is no `packets` table).
            The connection is closed either way.
    """
    top_src_sql = """
        SELECT src_ip, COUNT(*) AS c
        FROM packets
        WHERE src_ip IS NOT NULL
        GROUP BY src_ip
        ORDER BY c DESC
        LIMIT 5
    """
    top_dst_sql = """
        SELECT dst_ip, COUNT(*) AS c
        FROM packets
        WHERE dst_ip IS NOT NULL
        GROUP BY dst_ip
        ORDER BY c DESC
        LIMIT 5
    """

    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        cur = conn.cursor()

        cur.execute("SELECT COUNT(*) FROM packets")
        total_packets = cur.fetchone()[0] or 0

        cur.execute("SELECT protocol, COUNT(*) FROM packets GROUP BY protocol")
        proto_counts = cur.fetchall()

        cur.execute(top_src_sql)
        top_src = cur.fetchall()

        cur.execute(top_dst_sql)
        top_dst = cur.fetchall()
    finally:
        conn.close()

    return total_packets, proto_counts, top_src, top_dst


def _esc(value):
    """Return *value* as HTML-safe text for element content or a quoted attribute.

    None (a column that does not apply to the packet) renders as an empty
    string rather than the literal text "None".
    """
    if value is None:
        return ""
    return html.escape(str(value), quote=True)


def _summary_table(caption, header, pairs):
    """Return the HTML fragments for one captioned (label, count) summary table."""
    parts = [
        f"<div style='margin-top:6px;'><strong>{caption}:</strong></div>",
        f"<table class='summary-table'><tr><th>{header}</th><th>Count</th></tr>",
    ]
    parts.extend(
        f"<tr><td>{_esc(label)}</td><td>{_esc(count)}</td></tr>"
        for label, count in pairs
    )
    parts.append("</table>")
    return parts


def _packet_row(row):
    """Return the clickable <tr> for one `packets` row (columns in SELECT order)."""
    pid, ts, proto, src_ip, dst_ip, sport, dport, payload_hex, payload_ascii = row
    # Payloads are raw network data and may contain quotes, backslashes,
    # newlines or markup. json.dumps() turns each one into a valid JavaScript
    # string literal, and _esc() makes the whole call safe inside the
    # double-quoted onclick attribute. The browser undoes the HTML escaping
    # before running the script, so showModal() receives the original text.
    onclick = "showModal({}, {})".format(
        json.dumps(payload_hex or ""), json.dumps(payload_ascii or "")
    )
    cells = "".join(
        f"<td>{_esc(v)}</td>" for v in (pid, ts, proto, src_ip, dst_ip, sport, dport)
    )
    preview = _esc((payload_ascii or "")[:40])
    return f'<tr onclick="{_esc(onclick)}">{cells}<td>{preview}</td></tr>'


def generate_html():
    """Render the `packets` table in DB_PATH as a standalone HTML report at HTML_PATH.

    The page contains:

    - summary statistics from get_stats();
    - a client-side filter box;
    - one row per packet;
    - a modal that shows the full hex and decoded payload of a clicked row.

    A meta tag asks the browser to reload the page every 5 seconds. Every
    value taken from the database is escaped before it is written, because
    packet payloads are arbitrary network data.
    """
    os.makedirs(os.path.dirname(HTML_PATH), exist_ok=True)

    conn = sqlite3.connect(DB_PATH)
    try:
        rows = conn.execute(
            "SELECT id, timestamp, protocol, src_ip, dst_ip, src_port, dst_port, payload_hex, payload_ascii FROM packets ORDER BY id ASC"
        ).fetchall()
    finally:
        conn.close()

    total_packets, proto_counts, top_src, top_dst = get_stats()

    out = []

    # HTML HEADER. The file is written as UTF-8; declaring the charset keeps
    # browsers from guessing a different encoding when it is opened from disk.
    out.append(
        "<html><head>"
        "<meta charset='utf-8'>"
        "<title>Wireshark Loopback Address Capture Report</title>"
        "<meta http-equiv='refresh' content='5'>"
        "<style>"
        "body { font-family: Arial, sans-serif; margin: 10px; }"
        "h2 { margin-bottom: 5px; }"
        ".summary { margin-bottom: 15px; font-size: 13px; }"
        ".summary-table { border-collapse: collapse; margin-top: 5px; }"
        ".summary-table th, .summary-table td { border: 1px solid #ccc; padding: 4px 6px; font-size: 12px; }"
        "table { border-collapse: collapse; width: 100%; margin-top: 10px; }"
        "th, td { border: 1px solid #ccc; padding: 6px; font-size: 12px; }"
        "th { background-color: #f2f2f2; }"
        "#filterBox { margin-bottom: 10px; padding: 6px; width: 40%; }"
        "tr:hover { background-color: #f9f9f9; cursor: pointer; }"
        "#modalOverlay { display:none; position:fixed; z-index:9999; left:0; top:0; width:100%; height:100%; background-color:rgba(0,0,0,0.5); }"
        "#modalContent { background-color:#fff; margin:5% auto; padding:15px; border:1px solid #888; width:80%; max-height:80%; overflow:auto; font-size:12px; }"
        "#modalClose { float:right; cursor:pointer; font-weight:bold; }"
        "pre { white-space:pre-wrap; word-wrap:break-word; font-family:Consolas, monospace; font-size:11px; }"
        "</style>"
        "<script>"
        "function filterTable(){let i=document.getElementById('filterBox').value.toLowerCase();"
        "let r=document.querySelectorAll('#packetTable tr');"
        "r.forEach((row,x)=>{if(x===0)return;let t=row.innerText.toLowerCase();row.style.display=t.includes(i)?'':'none';});}"
        "function showModal(h,a){document.getElementById('modalHex').innerText=h||'';"
        "document.getElementById('modalAscii').innerText=a||'';"
        "document.getElementById('modalOverlay').style.display='block';}"
        "function closeModal(){document.getElementById('modalOverlay').style.display='none';}"
        "</script>"
        "</head><body>"
    )

    # TITLE
    out.append("<h2>Wireshark Loopback Address Capture Report</h2>")

    # SUMMARY SECTION
    out.append("<div class='summary'>")
    out.append(f"<div><strong>Total Packets:</strong> {total_packets}</div>")
    # A NULL protocol means the packet was neither TCP nor UDP (that is the
    # only case in which parse_pcap_to_db leaves the column unset).
    out.extend(_summary_table(
        "Protocol Counts",
        "Protocol",
        [("Other" if proto is None else proto, count) for proto, count in proto_counts],
    ))
    out.extend(_summary_table("Top Source IPs", "Source IP", top_src))
    out.extend(_summary_table("Top Destination IPs", "Destination IP", top_dst))
    out.append("</div>")  # end summary

    # FILTER + TABLE HEADER
    out.append("<input id='filterBox' type='text' onkeyup='filterTable()' placeholder='Filter packets...'>")
    out.append(
        "<table id='packetTable'>"
        "<tr>"
        "<th>ID</th>"
        "<th>Timestamp</th>"
        "<th>Protocol</th>"
        "<th>Src IP</th>"
        "<th>Dst IP</th>"
        "<th>Src Port</th>"
        "<th>Dst Port</th>"
        "<th>Payload (Preview)</th>"
        "</tr>"
    )

    # TABLE ROWS
    out.extend(_packet_row(row) for row in rows)
    out.append("</table>")

    # MODAL (filled in by showModal() when a row is clicked)
    out.append(
        "<div id='modalOverlay' onclick='closeModal()'>"
        "<div id='modalContent' onclick='event.stopPropagation();'>"
        "<span id='modalClose' onclick='closeModal()'>[X]</span>"
        "<h3>Packet Payload</h3>"
        "<h4>Hex</h4>"
        "<pre id='modalHex'></pre>"
        "<h4>ASCII</h4>"
        "<pre id='modalAscii'></pre>"
        "</div>"
        "</div>"
    )

    out.append("</body></html>")

    with open(HTML_PATH, "w", encoding="utf-8") as f:
        f.write("\n".join(out))


def main():
    """Run the pipeline end to end and report where each artefact was written.

    The steps are: fetch the remote capture, load it into SQLite, then render
    the HTML report.
    """
    pipeline = (download_pcap, parse_pcap_to_db, generate_html)
    for stage in pipeline:
        stage()

    summary_lines = (
        f"Remote PCAPNG downloaded to: {PCAP_LOCAL}",
        f"Real-time DB updated at: {DB_PATH}",
        f"HTML dashboard generated at: {HTML_PATH}",
    )
    for line in summary_lines:
        print(line)


if __name__ == "__main__":
    main()
