#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: 2025 Peter Lemenkov <lemenkov@gmail.com>
# SPDX-License-Identifier: MIT
#
"""
Round-Robin HTTP Proxy with Health Monitoring
Polls backend servers every minute and routes to the healthiest one.
"""

import aiohttp
import asyncio
import json
import logging
import os
import signal
import time
from aiohttp import web, ClientTimeout
from dataclasses import dataclass, field
from typing import List, Optional

# Configure root logging at INFO and create this module's logger; individual
# logger levels can later be changed at runtime via the /loglevel endpoint.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Reduce aiohttp access log verbosity: only WARNING and above are emitted for
# per-request access records.
logging.getLogger("aiohttp.access").setLevel(logging.WARNING)


@dataclass
class ServerHealth:
    """Mutable health record for one backend server, updated by HealthMonitor."""

    # Base URL of the backend; request paths are appended to it when proxying.
    url: str
    # Servers start out unhealthy and become healthy after a successful check.
    is_healthy: bool = False
    # time.time() timestamp of the most recent health check (0 = never checked).
    last_check: float = 0
    # Seconds taken by the last successful health probe; inf when unknown/failed.
    response_time: float = float("inf")
    # Failures since the last successful check; reset to 0 on success.
    consecutive_failures: int = 0
    # Message describing the most recent failure, or None after a success.
    last_error: Optional[str] = None
    # "result" of the backend's web3_clientVersion response, once fetched.
    client_version: Optional[str] = None


@dataclass
class ProxyConfig:
    """Static settings for the proxy, its management API and the health monitor."""

    # Backend base URLs that form the initial server pool.
    servers: List[str] = field(default_factory=list)
    # Host of the reference node; queried over HTTPS for the current block number.
    reference_addr: str = "localhost"
    # Key sent as the "dkey" query parameter to the reference node (placeholder default).
    server_key: str = "<get one please>"
    health_check_interval: int = 60  # seconds
    health_check_timeout: int = 5  # seconds
    proxy_host: str = "0.0.0.0"  # Public interface for client connections
    proxy_port: int = 8080
    management_host: str = "127.0.0.1"  # Management interface (localhost only)
    management_port: int = 8081
    # Failures in a row after which a server is marked unhealthy.
    max_consecutive_failures: int = 3
    # Passed to aiohttp as client_max_size for the public proxy application.
    max_request_size: int = 10 * 1024 * 1024  # 10MB default
    eth_getLogs_limit: int = 75  # Max concurrent eth_getLogs requests
    drain_delay: int = (
        40  # seconds to keep serving after SIGTERM while the LB drains us
    )


class HealthMonitor:
    """Poll the backend pool and record each server's health.

    A reference node supplies the current chain head.  A backend counts as
    healthy when it answers ``eth_blockNumber`` with HTTP 200 and is at most
    ``MAX_BLOCK_LAG`` blocks behind that head.
    """

    # Largest number of blocks a backend may trail the reference node by.
    MAX_BLOCK_LAG = 5

    def __init__(self, config: ProxyConfig):
        self.config = config
        self.servers = [ServerHealth(url=url) for url in config.servers]
        self.session: Optional[aiohttp.ClientSession] = None
        self.current_server_index = 0
        self.current_block: Optional[int] = None
        # Strong reference to the polling task: the event loop only keeps weak
        # references to tasks, and stop() needs the handle to cancel it.
        self._health_task: Optional[asyncio.Task] = None

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return a non-empty description (timeouts stringify to '')."""
        return str(exc) or type(exc).__name__

    def _redact(self, text: str) -> str:
        """Remove the reference-node key from text that is about to be logged."""
        key = self.config.server_key
        return text.replace(key, "<redacted>") if key else text

    async def start(self):
        """Open the health-check HTTP session and start the polling task."""
        timeout = ClientTimeout(total=self.config.health_check_timeout)
        self.session = aiohttp.ClientSession(timeout=timeout)

        # Start health checking task (reference kept so it can be cancelled)
        self._health_task = asyncio.create_task(self.health_check_loop())
        logger.info(f"Health monitor started for servers: {self.config.servers}")

    async def stop(self):
        """Cancel the polling task, then close the HTTP session."""
        task, self._health_task = self._health_task, None
        if task is not None:
            task.cancel()
            # gather(..., return_exceptions=True) absorbs the task's own
            # CancelledError while still letting a cancellation of stop()
            # itself propagate.
            await asyncio.gather(task, return_exceptions=True)
        if self.session:
            await self.session.close()

    async def health_check_loop(self):
        """Run a health-check round every ``health_check_interval`` seconds.

        An unexpected error in one round is logged and does not end the loop.
        """
        while True:
            try:
                await self.check_all_servers()
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("Health check round failed")
            await asyncio.sleep(self.config.health_check_interval)

    async def get_reference_block(self) -> Optional[int]:
        """Fetch the current block number from the reference server.

        Returns:
            The block number, or None when the request fails, the status is
            not 200, or the response cannot be parsed.
        """
        payload = {"jsonrpc": "2.0", "method": "eth_blockNumber", "params": [], "id": 1}
        base_url = f"https://{self.config.reference_addr}/ogrpc?network=ethereum"
        main_url = f"{base_url}&dkey={self.config.server_key}"

        try:
            # Never log the real key: it is a credential.
            logger.info(f"Checking {base_url}&dkey=<redacted>")
            async with self.session.post(main_url, json=payload) as response:
                if response.status == 200:
                    result = await response.json()
                    return int(result["result"], 16)
                logger.error(f"Failed to get reference block: HTTP {response.status}")
        except Exception as e:
            # Some aiohttp errors embed the request URL, so scrub the key.
            logger.error(
                f"Failed to get reference block: {self._redact(self._describe_error(e))}"
            )
        return None

    async def fetch_client_version(self, server: ServerHealth):
        """Store the server's ``web3_clientVersion`` result; failures only warn."""
        payload = {
            "jsonrpc": "2.0",
            "method": "web3_clientVersion",
            "params": [],
            "id": 1,
        }

        try:
            async with self.session.post(server.url, json=payload) as response:
                if response.status == 200:
                    result = await response.json()
                    server.client_version = result.get("result")
                    logger.info(f"Server {server.url} version: {server.client_version}")
        except Exception as e:
            logger.warning(f"Failed to get client version from {server.url}: {e}")

    async def check_all_servers(self):
        """Check all servers concurrently against a fresh reference block.

        When the reference block is unavailable the round is skipped and every
        server keeps its previous state.
        """
        current_block = await self.get_reference_block()
        if current_block is None:
            return

        self.current_block = current_block

        # Snapshot the pool: the management API may add/remove servers while
        # the checks are in flight.
        checks = [
            self.check_server_health(server, current_block)
            for server in list(self.servers)
        ]
        await asyncio.gather(*checks, return_exceptions=True)
        self.log_server_status()

    async def check_server_health(self, server: ServerHealth, current_block: int):
        """Probe one server with ``eth_blockNumber`` and update its record.

        Args:
            server: Record to update in place.
            current_block: Chain head reported by the reference node.
        """
        payload = {"jsonrpc": "2.0", "method": "eth_blockNumber", "params": [], "id": 1}
        was_healthy = server.is_healthy
        started = time.monotonic()  # monotonic: immune to wall-clock jumps

        try:
            async with self.session.post(server.url, json=payload) as response:
                response_time = time.monotonic() - started

                if response.status != 200:
                    self.mark_server_unhealthy(server, f"HTTP {response.status}")
                    return

                result = await response.json()
                node_block = int(result["result"], 16)
                if current_block - node_block > self.MAX_BLOCK_LAG:  # Allow a small lag
                    self.mark_server_unhealthy(
                        server,
                        f"Node lagging (block {node_block}, current {current_block})",
                        immediate=True,
                    )
                    return

                server.is_healthy = True
                server.response_time = response_time
                server.consecutive_failures = 0
                server.last_error = None

                # Fetch version on first healthy check or re-join
                if not was_healthy or server.client_version is None:
                    await self.fetch_client_version(server)

                logger.debug(
                    f"Server {server.url} is healthy (response time: {response_time:.3f}s)"
                )
        except Exception as e:
            self.mark_server_unhealthy(server, self._describe_error(e))
        finally:
            # Record the check time on every path, including early returns.
            server.last_check = time.time()

    async def check_single_server(self, server: ServerHealth):
        """Check one server, reusing the cached reference block when available.

        Used for the immediate health check when a new server is added.
        """
        current_block = self.current_block
        if current_block is None:
            current_block = await self.get_reference_block()

        if current_block is None:
            self.mark_server_unhealthy(server, "Could not get reference block")
            return

        await self.check_server_health(server, current_block)

    def mark_server_unhealthy(
        self, server: ServerHealth, error: str, immediate: bool = False
    ):
        """Record a failure and mark the server unhealthy once the threshold is hit.

        Args:
            server: The server to mark unhealthy
            error: Error message
            immediate: If True, mark unhealthy immediately without waiting for threshold
        """
        server.consecutive_failures += 1
        server.last_error = error
        server.response_time = float("inf")

        if (
            immediate
            or server.consecutive_failures >= self.config.max_consecutive_failures
        ):
            # Only log on the healthy -> unhealthy transition.
            if server.is_healthy:
                logger.warning(f"Server {server.url} marked as unhealthy: {error}")
            server.is_healthy = False

    def log_server_status(self):
        """Log a summary line plus one status line per server."""
        healthy_count = sum(1 for s in self.servers if s.is_healthy)
        logger.info(
            f"Health check complete: {healthy_count}/{len(self.servers)} servers healthy"
        )

        for server in self.servers:
            status = "✓" if server.is_healthy else "✗"
            rt = (
                f"{server.response_time:.3f}s"
                if server.response_time != float("inf")
                else "N/A"
            )
            version = f" [{server.client_version}]" if server.client_version else ""
            error_info = f" ({server.last_error})" if server.last_error else ""
            logger.info(
                f"  {status} {server.url}{version} - Response time: {rt}{error_info}"
            )


class HTTPProxy:
    """Public reverse proxy plus a localhost-only management API.

    Requests are forwarded to the healthy backend with the lowest measured
    response time, falling back to the next one on failure.  ``eth_getLogs``
    calls are bounded by a semaphore.
    """

    # Hop-by-hop headers (RFC 7230 section 6.1) that a proxy must not forward.
    # "trailers" is kept for backward compatibility; the real header name is
    # "Trailer".
    _HOP_BY_HOP = frozenset(
        {
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailer",
            "trailers",
            "transfer-encoding",
            "upgrade",
        }
    )

    def __init__(self, config: ProxyConfig):
        self.config = config
        self.health_monitor = HealthMonitor(config)
        self.client_session: Optional[aiohttp.ClientSession] = None
        self.eth_getLogs_semaphore = asyncio.Semaphore(config.eth_getLogs_limit)
        self.draining = False
        self._shutdown_event: Optional[asyncio.Event] = None
        # Strong references to fire-and-forget tasks; the event loop itself
        # only holds weak references to tasks.
        self._background_tasks: set = set()

    def _spawn_background(self, coro) -> asyncio.Task:
        """Schedule *coro* as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def start_server(self):
        """Start the health monitor, the proxy server and the management server.

        Returns:
            Tuple ``(proxy_runner, mgmt_runner)`` of the two ``web.AppRunner``
            objects so the caller can clean them up on shutdown.
        """
        await self.health_monitor.start()
        self._shutdown_event = asyncio.Event()

        # Create client session for proxying requests
        self.client_session = aiohttp.ClientSession()

        # Raise the request body cap to the configured size (aiohttp's own
        # default is 1 MiB).
        proxy_app = web.Application(client_max_size=self.config.max_request_size)
        proxy_app.router.add_get("/_health", self.lb_health_handler)
        proxy_app.router.add_route("*", "/{path:.*}", self.proxy_handler)

        # Private management application (localhost only)
        mgmt_app = web.Application()
        mgmt_app.router.add_get("/health", self.health_status_handler)
        mgmt_app.router.add_post("/loglevel", self.set_loglevel_handler)
        mgmt_app.router.add_post("/servers/add", self.add_server_handler)
        mgmt_app.router.add_post("/servers/remove", self.remove_server_handler)

        # Start proxy server on public interface
        proxy_runner = web.AppRunner(proxy_app)
        await proxy_runner.setup()
        proxy_site = web.TCPSite(
            proxy_runner, self.config.proxy_host, self.config.proxy_port
        )
        await proxy_site.start()
        logger.info(
            f"Proxy server started on http://{self.config.proxy_host}:{self.config.proxy_port}"
        )
        logger.info(
            f"eth_getLogs rate limit: {self.config.eth_getLogs_limit} concurrent requests"
        )

        # Start management server on localhost only
        mgmt_runner = web.AppRunner(mgmt_app)
        await mgmt_runner.setup()
        mgmt_site = web.TCPSite(
            mgmt_runner, self.config.management_host, self.config.management_port
        )
        await mgmt_site.start()
        logger.info(
            f"Management server started on http://{self.config.management_host}:{self.config.management_port}"
        )

        return (proxy_runner, mgmt_runner)

    async def lb_health_handler(self, request: web.Request) -> web.Response:
        """Liveness/readiness probe for the load balancer.

        Returns 503 while draining, or when no backend is healthy; 200 otherwise.
        Read-only: exposes only a count, safe to serve on the public port.
        """
        total = len(self.health_monitor.servers)
        if self.draining:
            return web.json_response(
                {
                    "status": "draining",
                    "healthy_backends": 0,
                    "total_backends": total,
                },
                status=503,
            )
        healthy_count = sum(1 for s in self.health_monitor.servers if s.is_healthy)
        return web.json_response(
            {
                "healthy_backends": healthy_count,
                "total_backends": total,
            },
            status=200 if healthy_count > 0 else 503,
        )

    async def _drain_and_shutdown(self):
        """Flip the probe to 503, keep serving for drain_delay, then signal exit."""
        if self.draining:
            return
        self.draining = True
        logger.info(
            f"Shutdown signal received: /_health now reports 503. "
            f"Serving for {self.config.drain_delay}s while the load balancer stops routing..."
        )
        await asyncio.sleep(self.config.drain_delay)
        logger.info("Drain window elapsed; finishing in-flight requests and exiting.")
        self._shutdown_event.set()

    async def health_status_handler(self, request: web.Request) -> web.Response:
        """Return current health status of all backend servers

        Example usage (management interface only):
            curl http://localhost:8081/health
        """
        servers = self.health_monitor.servers
        servers_status = [
            {
                "url": server.url,
                "healthy": server.is_healthy,
                # JSON has no infinity; report "unknown" as null.
                "response_time": (
                    server.response_time
                    if server.response_time != float("inf")
                    else None
                ),
                "last_check": server.last_check,
                "consecutive_failures": server.consecutive_failures,
                "last_error": server.last_error,
                "client_version": server.client_version,
            }
            for server in servers
        ]

        return web.json_response(
            {
                "servers": servers_status,
                "healthy_count": sum(1 for s in servers if s.is_healthy),
                "total_count": len(servers),
                "eth_getLogs_limit": self.config.eth_getLogs_limit,
            }
        )

    async def set_loglevel_handler(self, request: web.Request) -> web.Response:
        """Dynamically change log level

        The target defaults to this module's logger (named "__main__" when the
        file is run as a script).

        Example usage (management interface only):
            curl -X POST http://localhost:8081/loglevel -H "Content-Type: application/json" -d '{"level": "DEBUG"}'
            curl -X POST http://localhost:8081/loglevel -H "Content-Type: application/json" -d '{"level": "INFO", "logger": "aiohttp.access"}'
        """
        try:
            data = await request.json()
            level = data.get("level", "").upper()
            # Default to this module's logger, whatever name it was given.
            logger_name = data.get("logger", logger.name)

            # Validate level
            valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
            if level not in valid_levels:
                return web.json_response(
                    {
                        "error": f"Invalid level. Must be one of: {', '.join(valid_levels)}"
                    },
                    status=400,
                )

            # Get the logger and set level
            target_logger = logging.getLogger(logger_name)
            target_logger.setLevel(getattr(logging, level))

            logger.info(f"Log level for '{logger_name}' changed to {level}")

            return web.json_response(
                {"success": True, "logger": logger_name, "level": level}
            )

        except Exception as e:
            # Malformed JSON or a non-object body ends up here.
            return web.json_response({"error": str(e)}, status=400)

    async def add_server_handler(self, request: web.Request) -> web.Response:
        """Add a new backend server to the pool

        The URL must start with http:// or https://.  The new server starts
        unhealthy; an immediate background health check is scheduled for it.

        Example usage (management interface only):
            curl -X POST http://localhost:8081/servers/add -H "Content-Type: application/json" -d '{"url": "http://172.31.28.136:8545"}'
        """
        try:
            data = await request.json()
            server_url = data.get("url", "").strip()

            if not server_url:
                return web.json_response(
                    {"error": "Missing 'url' parameter"}, status=400
                )

            # A URL without a scheme can never be reached by the HTTP client
            # and would sit in the pool permanently unhealthy.
            if not server_url.lower().startswith(("http://", "https://")):
                return web.json_response(
                    {"error": "Invalid 'url': must start with http:// or https://"},
                    status=400,
                )

            # Check if server already exists
            if any(s.url == server_url for s in self.health_monitor.servers):
                return web.json_response(
                    {"error": f"Server {server_url} already exists"}, status=400
                )

            # Add new server
            new_server = ServerHealth(url=server_url)
            self.health_monitor.servers.append(new_server)

            # Trigger immediate health check for the new server
            self._spawn_background(self.health_monitor.check_single_server(new_server))

            logger.info(f"Added new server: {server_url}")

            return web.json_response(
                {
                    "success": True,
                    "message": f"Server {server_url} added successfully",
                    "total_servers": len(self.health_monitor.servers),
                }
            )

        except Exception as e:
            return web.json_response({"error": str(e)}, status=400)

    async def remove_server_handler(self, request: web.Request) -> web.Response:
        """Remove a backend server from the pool

        Responds 404 for an unknown URL and 400 when asked to remove the last
        remaining server.

        Example usage (management interface only):
            curl -X POST http://localhost:8081/servers/remove -H "Content-Type: application/json" -d '{"url": "http://172.31.28.136:8545"}'
        """
        try:
            data = await request.json()
            server_url = data.get("url", "").strip()

            if not server_url:
                return web.json_response(
                    {"error": "Missing 'url' parameter"}, status=400
                )

            # Find the server
            server_to_remove = next(
                (s for s in self.health_monitor.servers if s.url == server_url), None
            )

            if not server_to_remove:
                return web.json_response(
                    {"error": f"Server {server_url} not found"}, status=404
                )

            # Check if this is the last server
            if len(self.health_monitor.servers) == 1:
                return web.json_response(
                    {"error": "Cannot remove the last server"}, status=400
                )

            self.health_monitor.servers.remove(server_to_remove)

            logger.info(f"Removed server: {server_url}")

            return web.json_response(
                {
                    "success": True,
                    "message": f"Server {server_url} removed successfully",
                    "total_servers": len(self.health_monitor.servers),
                }
            )

        except Exception as e:
            return web.json_response({"error": str(e)}, status=400)

    async def _forward_request(
        self,
        request: web.Request,
        target_server: ServerHealth,
        target_url: str,
        headers: dict,
        data: bytes,
    ) -> Optional[web.Response]:
        """Forward a request to a target server.

        Returns:
            The response to hand back to the client, or None when the backend
            failed before anything was sent to the client (the caller may then
            try another backend).

        Raises:
            Exception: re-raised when the failure happens after a streamed
                response was already started; retrying would write a second
                response onto the same connection.
        """
        response_started = False
        try:
            async with self.client_session.request(
                method=request.method,
                url=target_url,
                headers=headers,
                data=data,
                allow_redirects=False,
            ) as response:
                # Drop hop-by-hop headers AND content-encoding/content-length,
                # because aiohttp automatically decompresses the body.
                headers_to_remove = self._HOP_BY_HOP | {
                    "content-encoding",
                    "content-length",
                }
                resp_headers = {
                    k: v
                    for k, v in response.headers.items()
                    if k.lower() not in headers_to_remove
                }

                if response.headers.get("Transfer-Encoding", "").lower() == "chunked":
                    # For chunked responses, stream the content
                    resp = web.StreamResponse(
                        status=response.status, headers=resp_headers
                    )
                    # From here on bytes may already be on the wire.
                    response_started = True
                    await resp.prepare(request)

                    async for chunk in response.content.iter_chunked(8192):
                        await resp.write(chunk)

                    await resp.write_eof()
                    return resp

                # For regular responses, read the full body
                body = await response.read()

                # Ensure Content-Length matches the (decompressed) body
                resp_headers["Content-Length"] = str(len(body))

                return web.Response(
                    status=response.status, headers=resp_headers, body=body
                )

        except Exception as e:
            logger.error(f"Error proxying request to {target_url}: {e}")

            # Mark this server as potentially unhealthy
            self.health_monitor.mark_server_unhealthy(
                target_server, f"Connection error: {str(e)}"
            )

            if response_started:
                # Part of a response was already sent to the client; the only
                # safe outcome is to abort this request.
                raise

            return None

    async def proxy_handler(self, request: web.Request) -> web.Response:
        """Handle incoming requests and proxy them to a healthy backend

        Responds 503 when no backend is healthy.  Requests whose JSON-RPC body
        (single call or batch) contains ``eth_getLogs`` are rate limited.
        """
        path = request.path_qs

        # Try all healthy servers in order of preference
        healthy_servers = [s for s in self.health_monitor.servers if s.is_healthy]
        healthy_servers.sort(key=lambda s: s.response_time)

        if not healthy_servers:
            return web.Response(
                status=503,
                text="Service Unavailable: No healthy backend servers",
                content_type="text/plain",
            )

        data = await request.read()
        rpc_method = None
        rpc_params = None
        throttle = False
        try:
            parsed = json.loads(data)
        except (ValueError, RecursionError):
            # ValueError covers JSONDecodeError and the UnicodeDecodeError
            # raised for non-UTF-8 bodies; such requests are proxied as-is.
            parsed = None

        if isinstance(parsed, dict):
            rpc_method = parsed.get("method")
            rpc_params = parsed.get("params", [])
            logger.debug(rpc_method)
            throttle = rpc_method == "eth_getLogs"
        elif isinstance(parsed, list):
            # JSON-RPC batch: rate limit it if any call is eth_getLogs.
            throttle = any(
                isinstance(call, dict) and call.get("method") == "eth_getLogs"
                for call in parsed
            )

        # Prepare headers (remove hop-by-hop headers)
        headers = {
            k: v
            for k, v in request.headers.items()
            if k.lower() not in self._HOP_BY_HOP
        }

        # Rate limit eth_getLogs to prevent disk I/O saturation
        if throttle:
            async with self.eth_getLogs_semaphore:
                return await self._try_servers(
                    request,
                    healthy_servers,
                    path,
                    headers,
                    data,
                    rpc_method,
                    rpc_params,
                )
        return await self._try_servers(
            request, healthy_servers, path, headers, data, rpc_method, rpc_params
        )

    @staticmethod
    def _log_get_logs_size(result, rpc_params) -> None:
        """Log the size of an eth_getLogs response with its filter summary."""
        try:
            filter_obj = rpc_params[0] if rpc_params else {}
            from_block = filter_obj.get("fromBlock", "?")
            to_block = filter_obj.get("toBlock", "?")
            address = filter_obj.get("address", "?")
            # Truncate address for logging
            if isinstance(address, str) and len(address) > 12:
                address = address[:12] + "..."
            elif isinstance(address, list):
                address = f"[{len(address)} addrs]"
            logger.debug(
                f"eth_getLogs: {len(result.body)} bytes, "
                f"from={from_block} to={to_block} addr={address}"
            )
        except Exception as e:
            logger.warning(
                f"eth_getLogs: {len(result.body)} bytes (parse error: {e})"
            )

    async def _try_servers(
        self,
        request: web.Request,
        healthy_servers: List[ServerHealth],
        path: str,
        headers: dict,
        data: bytes,
        rpc_method: Optional[str] = None,
        rpc_params: Optional[List] = None,
    ) -> web.Response:
        """Try forwarding to each healthy server until one succeeds.

        Returns the first backend response, or a 502 when every server failed.
        """
        last_error = None
        last_index = len(healthy_servers) - 1

        for attempt, target_server in enumerate(healthy_servers):
            target_url = f"{target_server.url.rstrip('/')}{path}"

            result = await self._forward_request(
                request, target_server, target_url, headers, data
            )
            if result is not None:
                # Log eth_getLogs response size for analysis (streamed
                # responses have no body attribute and are skipped)
                if (
                    rpc_method == "eth_getLogs"
                    and hasattr(result, "body")
                    and result.body
                ):
                    self._log_get_logs_size(result, rpc_params)
                return result

            last_error = f"Failed to connect to {target_server.url}"

            if attempt < last_index:
                logger.warning("Retrying with next available server...")

        # All servers failed
        logger.error(
            f"All {len(healthy_servers)} healthy servers failed to handle request"
        )
        return web.Response(
            status=502,
            text=f"Bad Gateway: All backend servers failed. Last error: {last_error}",
            content_type="text/plain",
        )

    async def stop(self):
        """Cancel background tasks, close the upstream session, stop the monitor."""
        pending = list(self._background_tasks)
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        if self.client_session:
            await self.client_session.close()
        await self.health_monitor.stop()


async def main():
    """Build the configuration from the environment and run the proxy.

    Required environment variables: INSTANCES (comma-separated backend hosts),
    REFERENCE_ADDR, SERVER_KEY and BIND_ADDR.  Optional: ETH_GETLOGS_LIMIT and
    SHUTDOWN_DRAIN_SECONDS.  SIGTERM triggers a graceful drain; SIGINT exits
    promptly.

    Raises:
        KeyError: if a required environment variable is missing.
    """
    # Tolerate spaces and stray commas ("a, b,") instead of producing
    # unusable backends such as "http://:8545".
    hosts = [h.strip() for h in os.environ["INSTANCES"].split(",")]
    hosts = [h for h in hosts if h]
    if not hosts:
        logger.warning(
            "INSTANCES lists no backend hosts; add servers via the management API"
        )

    # Configuration
    config = ProxyConfig(
        servers=[f"http://{host}:8545" for host in hosts],
        reference_addr=os.environ["REFERENCE_ADDR"],
        server_key=os.environ["SERVER_KEY"],
        health_check_interval=60,  # Check every minute
        health_check_timeout=5,  # 5 second timeout for health checks
        proxy_host=os.environ["BIND_ADDR"],
        proxy_port=8545,
        management_host="127.0.0.1",  # Management API only on localhost
        management_port=8081,
        eth_getLogs_limit=int(os.environ.get("ETH_GETLOGS_LIMIT", "75")),
        drain_delay=int(os.environ.get("SHUTDOWN_DRAIN_SECONDS", "40")),
    )

    # Start the proxy
    proxy = HTTPProxy(config)

    proxy_runner, mgmt_runner = await proxy.start_server()

    # The event loop keeps only weak references to tasks, so hold the drain
    # task here until it completes.
    drain_tasks = set()

    try:
        loop = asyncio.get_running_loop()

        def _on_sigterm():
            # systemd stop / RPM restart: full graceful drain
            task = asyncio.create_task(proxy._drain_and_shutdown())
            drain_tasks.add(task)
            task.add_done_callback(drain_tasks.discard)

        def _on_sigint():
            # interactive Ctrl+C: skip the drain wait, exit promptly
            proxy.draining = True
            proxy._shutdown_event.set()

        loop.add_signal_handler(signal.SIGTERM, _on_sigterm)
        loop.add_signal_handler(signal.SIGINT, _on_sigint)

        print(f"🚀 Proxy server running on http://{config.proxy_host}:{config.proxy_port}")
        print(
            f"🔧 Management API running on http://{config.management_host}:{config.management_port}"
        )
        print(f"📊 Monitoring servers: {config.servers}")
        print(f"🔍 Health checks every {config.health_check_interval} seconds")
        print(f"🚦 eth_getLogs rate limit: {config.eth_getLogs_limit} concurrent requests")
        print("\nManagement endpoints (localhost only):")
        print(
            f"  - Health status: curl http://{config.management_host}:{config.management_port}/health"
        )
        print(
            f"  - Change log level: curl -X POST http://{config.management_host}:{config.management_port}/loglevel -H 'Content-Type: application/json' -d '{{\"level\": \"DEBUG\"}}'"
        )
        print("\nPress Ctrl+C to stop...")

        await proxy._shutdown_event.wait()
    finally:
        # Runs on normal shutdown and when anything above fails, so the
        # listeners and HTTP sessions are always released.
        for task in list(drain_tasks):
            task.cancel()
        logger.info("Stopping listeners and draining in-flight requests...")
        await proxy_runner.cleanup()  # refuse new conns, let in-flight handlers finish
        await proxy.stop()  # close upstream sessions
        await mgmt_runner.cleanup()
        logger.info("Shutdown complete.")


# Script entry point: run main() on a fresh event loop until it returns.
if __name__ == "__main__":
    asyncio.run(main())
