#!/usr/bin/python3
# SPDX-License-Identifier: GPL-2.0-or-later
#
# /usr/sbin/pocketds-fancontrol
#
# Userspace PWM fan controller for the AYANEO Pocket DS (SM8550).
#
# Background: the QCS8550 thermal-zone DT used to bind every cpuss/gpuss
# trip-point to <&pwm_fan> via cooling-maps, letting the kernel ramp the
# fan automatically. We dropped those maps in
# qcs8550-ayaneo-pocket-common.dtsi (mirroring ROCKNIX 524d0c9) so the
# fan now sits idle unless something writes its hwmon pwm1 -- which is
# us.
#
# We discover the pwm-fan hwmon entry, average temperatures across all
# real thermal zones (skipping PMIC + battery), and step pwm1 according
# to a profile. The profile is read from /etc/pocketds-fancontrol/profile
# (one of: auto, quiet, moderate, aggressive, custom, off; anything else
# means moderate) and reloaded whenever the file's mtime changes, so no
# SIGHUP plumbing is needed. The file is installed with `install -m 0664`
# so the indicator app can rewrite it via pkexec/sudo.
#
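# Custom curves live in /etc/pocketds-fancontrol/custom.conf as
# `tempC=pwm` lines (e.g. `60=80`). Each point's pwm (0-255) applies from
# its temperature up to the next point; below the lowest point the fan is
# off. A missing or empty curve falls back to the moderate profile.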

import errno
import glob
import os
import signal
import sys
import time

PROFILE_FILE = "/etc/pocketds-fancontrol/profile"
CUSTOM_FILE = "/etc/pocketds-fancontrol/custom.conf"
STATE_DIR = "/run/pocketds-fancontrol"
STATE_FILE = os.path.join(STATE_DIR, "state")

# Sleep, in seconds, between control-loop iterations (sample + update).
POLL_INTERVAL = 3.0
# Minimum seconds between mtime checks of PROFILE_FILE / CUSTOM_FILE.
PROFILE_RECHECK_INTERVAL = 5.0

# Thermal-zone `type` strings to ignore. PMIC + battery sensors don't
# track die temperature and would skew the average low.
SKIP_ZONE_TYPES = {
    "pm8550-thermal", "pm8550b-thermal", "pm8550ve-thermal",
    "pmk8550-thermal", "battery",
}

# Fan duty per step on the 0-255 pwm1 scale (100%, 80%, 60%, ~47%, 40%,
# ~30%, 20%, off). Each profile below is seven descending thresholds in
# degrees C: the first threshold the temperature reaches selects the duty
# at the same index, and the trailing 0 is the duty below every threshold.
PWM_STEPS = (255, 204, 153, 119, 102, 77, 51, 0)

# Threshold tuples (degrees C, descending), one per PWM_STEPS entry except
# the final 0. "auto" currently uses the same curve as "moderate".
PROFILES = {
    "aggressive": (80, 75, 70, 65, 60, 55, 50),
    "moderate":   (85, 80, 75, 70, 65, 60, 55),
    "auto":       (85, 80, 75, 70, 65, 60, 55),
    "quiet":      (95, 85, 80, 75, 70, 65, 60),
}


def log(msg):
    """Emit one prefixed status line on stdout, flushed right away."""
    line = f"pocketds-fancontrol: {msg}"
    print(line, flush=True)


def find_pwm_path(root="/sys/class/hwmon"):
    """Return the hwmon ``pwm1`` attribute that drives the fan, or None.

    Candidates are ``<root>/hwmon*/pwm1`` in sorted order. Preference:
      1. a hwmon whose ``name`` is "pwmfan" or "pwm-fan";
      2. otherwise the first candidate that also exposes ``pwm1_enable``;
      3. otherwise the first candidate.
    The name match is checked across *all* candidates before the
    ``_enable`` heuristic, so an unrelated controllable hwmon that sorts
    earlier can't shadow the real fan.

    *root* defaults to the sysfs hwmon class directory; it is a parameter
    only so discovery can be pointed elsewhere (e.g. in tests).
    """
    candidates = sorted(glob.glob(os.path.join(root, "hwmon*", "pwm1")))

    for pwm in candidates:
        # The pwm-fan driver exposes pwm1 only; match it by hwmon name.
        try:
            with open(os.path.join(os.path.dirname(pwm), "name")) as f:
                name = f.read().strip()
        except OSError:
            continue
        if name in ("pwmfan", "pwm-fan"):
            return pwm

    # No name match: prefer a channel with manual-enable control.
    for pwm in candidates:
        if os.path.exists(pwm + "_enable"):
            return pwm

    # Fall back to the first pwm1 we find.
    return candidates[0] if candidates else None


def find_temp_paths():
    """List the ``temp`` files of every thermal zone worth averaging.

    Zones under /sys/devices/virtual/thermal are visited in sorted order.
    A zone is dropped when its ``type`` can't be read, when that type is
    listed in SKIP_ZONE_TYPES, or when it has no ``temp`` attribute.
    """
    selected = []
    for zone_dir in sorted(glob.glob("/sys/devices/virtual/thermal/thermal_zone*")):
        type_file = os.path.join(zone_dir, "type")
        try:
            with open(type_file) as fh:
                zone_type = fh.read().strip()
        except OSError:
            continue
        temp_file = os.path.join(zone_dir, "temp")
        if zone_type not in SKIP_ZONE_TYPES and os.path.exists(temp_file):
            selected.append(temp_file)
    return selected


def read_avg_temp_mc(paths):
    """Average the integer readings found in *paths* (floor division).

    Files that can't be opened or don't hold an integer are ignored.
    Returns 0 when no file yields a reading -- NOTE(review): callers can't
    tell that apart from a genuine 0 reading.
    """
    readings = []
    for path in paths:
        try:
            with open(path) as fh:
                readings.append(int(fh.read().strip()))
        except (OSError, ValueError):
            pass
    if not readings:
        return 0
    return sum(readings) // len(readings)


def parse_custom_conf(path):
    """Parse a custom fan curve from *path*.

    Each line is ``tempC=pwm``; text after ``#`` is a comment. Blank
    lines, lines without ``=`` and lines whose fields are not integers
    are skipped, and pwm is clamped to 0-255. The file is decoded as
    UTF-8 with undecodable bytes replaced, so stray garbage only
    invalidates its own line instead of raising UnicodeDecodeError (which
    is not an OSError and would otherwise escape to the caller).

    Returns a list of ``(temp, pwm)`` tuples sorted ascending (by temp,
    then pwm), or None if the file can't be read or yields no valid
    point. The curve is a step function: pwm_for_custom applies the point
    with the highest temp at or below the current temperature.
    """
    table = []
    try:
        with open(path, encoding="utf-8", errors="replace") as f:
            for raw in f:
                line = raw.split("#", 1)[0].strip()
                t_str, sep, p_str = line.partition("=")
                if not sep:
                    continue
                try:
                    # int() tolerates surrounding whitespace itself.
                    t = int(t_str)
                    p = int(p_str)
                except ValueError:
                    continue
                table.append((t, max(0, min(255, p))))
    except OSError:
        return None
    table.sort()
    return table or None


def read_profile():
    """Return the active profile name read from PROFILE_FILE.

    The content is stripped and lower-cased. Valid names are the keys of
    PROFILES plus "custom" and "off"; a missing/unreadable file or any
    other value yields "moderate". Undecodable bytes are replaced rather
    than raised, so a file containing invalid bytes falls back to
    "moderate" instead of crashing the control loop that re-reads it.
    """
    try:
        with open(PROFILE_FILE, encoding="utf-8", errors="replace") as f:
            v = f.read().strip().lower()
    except OSError:
        v = ""
    if v not in PROFILES and v not in ("custom", "off"):
        v = "moderate"
    return v


def pwm_for_thresholds(temp_c, thresholds):
    """Map *temp_c* onto PWM_STEPS using a descending threshold tuple.

    The first threshold that *temp_c* reaches selects the PWM_STEPS entry
    at the same index; below every threshold the result is 0 (fan off).
    """
    return next(
        (PWM_STEPS[idx] for idx, limit in enumerate(thresholds) if temp_c >= limit),
        0,
    )


def pwm_for_custom(temp_c, table):
    """Return the PWM for *temp_c* from a custom step curve.

    *table* is a list of ``(temp, pwm)`` points sorted ascending by temp.
    The point with the highest temp that is <= *temp_c* applies, so each
    point's pwm holds from its temperature up to the next point. Below the
    first point, or for an empty table, the result is 0.
    """
    pwm = 0
    for t, p in table:
        if t > temp_c:
            break
        pwm = p
    return pwm


def write_atomic(path, content):
    """Replace *path* with *content* via a sibling ``.tmp`` file + rename.

    The rename uses os.replace, so readers see either the old or the new
    file rather than a partial write. Errors are logged, not raised; on
    failure the temporary file is removed so a stale ``<path>.tmp`` isn't
    left behind.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "w") as f:
            f.write(content)
        os.replace(tmp, path)
    except OSError as e:
        try:
            os.unlink(tmp)
        except OSError:
            pass  # never created, or already gone
        log(f"write {path}: {e}")


def write_pwm(pwm_path, value):
    """Write *value* (as text) to the sysfs attribute *pwm_path*.

    Failures are logged rather than raised.
    """
    payload = str(value)
    try:
        with open(pwm_path, "w") as sysfs:
            sysfs.write(payload)
    except OSError as err:
        log(f"write {pwm_path}={value}: {err}")


def enable_pwm(pwm_path):
    """Set ``<pwm_path>_enable`` to "1" unless it already reads "1".

    Does nothing when the enable attribute doesn't exist. An unreadable
    attribute is treated as not enabled; a failed write is logged.
    """
    enable_attr = pwm_path + "_enable"
    if os.path.exists(enable_attr):
        try:
            with open(enable_attr) as fh:
                current = fh.read().strip()
        except OSError:
            current = ""
        if current == "1":
            return
        try:
            with open(enable_attr, "w") as fh:
                fh.write("1")
        except OSError as err:
            log(f"enable {enable_attr}: {err}")


def disable_pwm(pwm_path):
    """Write "0" to ``<pwm_path>_enable`` if that attribute exists.

    Best effort: it runs on the shutdown path, so a failed write is
    logged (matching enable_pwm) rather than raised or silently dropped.
    """
    en = pwm_path + "_enable"
    if not os.path.exists(en):
        return
    try:
        with open(en, "w") as f:
            f.write("0")
    except OSError as e:
        log(f"disable {en}: {e}")


def _pwm_for_profile(profile, temp_c, custom_table):
    """Return the PWM duty that *profile* selects at *temp_c* degrees C.

    "off" is always 0. "custom" follows *custom_table*, falling back to the
    moderate thresholds when the table is missing or empty. Any other name
    is looked up in PROFILES, defaulting to moderate.
    """
    if profile == "off":
        return 0
    if profile == "custom":
        if custom_table:
            return pwm_for_custom(temp_c, custom_table)
        # Fall back to moderate if custom curve is missing/empty.
        return pwm_for_thresholds(temp_c, PROFILES["moderate"])
    thresholds = PROFILES.get(profile, PROFILES["moderate"])
    return pwm_for_thresholds(temp_c, thresholds)


def main():
    """Discover the fan and sensors, then run the control loop.

    Returns 1 (for sys.exit) if no pwm1 hwmon attribute or no usable
    thermal zone is found; otherwise it loops forever. Each iteration
    averages the zone temperatures, maps them through the active profile,
    writes pwm1 only when the duty changes, and rewrites STATE_FILE. The
    profile and custom curve are re-checked by mtime at most every
    PROFILE_RECHECK_INTERVAL seconds. SIGTERM/SIGINT write pwm1=0, write
    0 to pwm1_enable (if present) and exit with status 0.
    """
    pwm_path = find_pwm_path()
    if not pwm_path:
        log("no pwm1 hwmon entry found; nothing to control")
        return 1
    log(f"pwm path: {pwm_path}")

    temp_paths = find_temp_paths()
    if not temp_paths:
        log("no thermal zones found; nothing to read")
        return 1
    log(f"thermal zones: {len(temp_paths)} sensors")

    enable_pwm(pwm_path)

    os.makedirs(STATE_DIR, exist_ok=True)
    os.chmod(STATE_DIR, 0o755)

    cleanup = {"done": False}

    def _shutdown(*_):
        if cleanup["done"]:
            return
        cleanup["done"] = True
        log("shutting down; releasing fan to kernel default")
        # Drop the fan back to off so we don't leave it spinning at full
        # tilt if the daemon dies.
        write_pwm(pwm_path, 0)
        disable_pwm(pwm_path)
        sys.exit(0)

    signal.signal(signal.SIGTERM, _shutdown)
    signal.signal(signal.SIGINT, _shutdown)
    signal.signal(signal.SIGHUP, lambda *_: None)  # reread happens on its own

    profile = read_profile()
    custom_table = parse_custom_conf(CUSTOM_FILE) if profile == "custom" else None
    profile_mtime = _mtime(PROFILE_FILE)
    custom_mtime = _mtime(CUSTOM_FILE)
    last_profile_check = 0.0
    last_pwm = -1  # not a valid duty, so the first computed value is written

    log(f"profile={profile}")

    while True:
        now = time.monotonic()

        if now - last_profile_check >= PROFILE_RECHECK_INTERVAL:
            last_profile_check = now
            new_profile_mtime = _mtime(PROFILE_FILE)
            new_custom_mtime = _mtime(CUSTOM_FILE)
            profile_reread = False
            if new_profile_mtime != profile_mtime:
                profile_mtime = new_profile_mtime
                profile = read_profile()
                profile_reread = True
                log(f"profile changed → {profile}")
            # custom_mtime only advances while "custom" is active, and
            # custom_table starts as None unless the startup profile was
            # custom. Re-parse whenever the profile file was re-read too;
            # otherwise switching to "custom" with an untouched custom.conf
            # would silently keep running the moderate fallback.
            if profile == "custom" and (
                profile_reread or new_custom_mtime != custom_mtime
            ):
                custom_mtime = new_custom_mtime
                custom_table = parse_custom_conf(CUSTOM_FILE)
                log(f"custom curve reloaded ({len(custom_table or [])} pts)")

        avg_c = read_avg_temp_mc(temp_paths) // 1000
        pwm = _pwm_for_profile(profile, avg_c, custom_table)

        if pwm != last_pwm:
            write_pwm(pwm_path, pwm)
            last_pwm = pwm

        write_atomic(
            STATE_FILE,
            f"profile={profile}\ntemp_c={avg_c}\npwm={pwm}\n",
        )

        time.sleep(POLL_INTERVAL)


def _mtime(path):
    try:
        return os.stat(path).st_mtime
    except OSError:
        return 0.0


if __name__ == "__main__":
    # Exit with main()'s status; a Ctrl-C that lands before main installs
    # its own SIGINT handler is treated as a clean exit.
    try:
        status = main()
    except KeyboardInterrupt:
        status = 0
    sys.exit(status)
