# This file is part of nvitop, the interactive NVIDIA-GPU process viewer.
# License: GNU GPL version 3.

# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
# pylint: disable=invalid-name

from __future__ import annotations

import itertools
import threading
import time
from collections import OrderedDict
from typing import TYPE_CHECKING, ClassVar

from nvitop.tui.library import (
    HOSTNAME,
    IS_SUPERUSER,
    IS_WINDOWS,
    NA,
    USER_CONTEXT,
    USERNAME,
    BufferedHistoryGraph,
    GpuProcess,
    HistoryGraph,
    Selection,
    WideString,
    bytes2human,
    cut_string,
    host,
    wcslen,
)
from nvitop.tui.screens.base import BaseSelectableScreen


if TYPE_CHECKING:
    import curses

    from nvitop.tui.tui import TUI


__all__ = ['ProcessMetricsScreen']


# pylint: disable-next=too-many-branches,too-many-locals
def get_yticks(history: HistoryGraph, y_offset: int) -> list[tuple[int, int]]:
    """Choose y-axis tick marks, labelled as percentages, for a history graph.

    Each candidate percentage from a fixed list is mapped to a fractional row height.
    For every integer row below ``height - 2``, the candidate whose fractional height
    is closest to that row (error relative to the percentage itself) is kept.

    From the resulting candidates:
    - With fewer than two candidates, all are returned.
    - On short graphs (``height < 12``), the better-placed of the top two is returned.
    - On taller graphs, the topmost candidate is returned. Its half is added when the
      top percentage is even and the half is at least 3. The 1.5x value is also added
      when it lands strictly between the top tick and ``height - 2``.

    Args:
        history: The graph to annotate. Reads ``height``, ``baseline``, ``bound``,
            ``max_bound``, ``upsidedown`` and a ``scale`` attribute that callers attach
            to the instance after construction.
        y_offset: Screen row of the graph's first line; added to every returned row.

    Returns:
        A list of ``(screen_row, percentage)`` pairs. Empty when the graph has no
        vertical range (``bound <= baseline``).
    """
    height = history.height
    baseline = history.baseline
    bound = history.bound
    max_bound = history.max_bound
    scale: float = history.scale  # type: ignore[attr-defined]
    upsidedown = history.upsidedown

    if bound <= baseline:
        # Degenerate vertical range: nothing to label, and ``p2h_f`` would divide by zero.
        return []

    def p2h_f(p: int) -> float:
        # Fractional row height (above the baseline) of ``p`` percent of the scaled maximum.
        return 0.01 * scale * p * (max_bound - baseline) * (height - 1) / (bound - baseline)

    max_height = height - 2
    percentages = (1, 2, 4, 5, 8, 10, 20, 40, 50, 80, 100, 200, 400, 500, 800, 1000)
    h2p: dict[int, int] = {}  # row -> best-fitting percentage for that row
    h2e: dict[int, float] = {}  # row -> relative placement error of that percentage
    for p in percentages:
        h_f = p2h_f(p)
        h = int(h_f)
        error = abs(h_f - h) / p
        if h not in h2p:
            if h < max_height:
                h2p[h] = p
                h2e[h] = error
        elif error < h2e[h]:
            h2p[h] = p
            h2e[h] = error

    candidates = sorted(h2p.items())
    ticks: list[tuple[int, int]]
    if len(candidates) >= 2:
        (hm1, pm1), (h2, p2) = candidates[-2:]
        if height < 12:
            # Short graph: a single tick, the more precisely placed of the top two.
            ticks = [(hm1, pm1)] if h2e[hm1] < h2e[h2] else [(h2, p2)]
        else:
            ticks = [(h2, p2)]
            if p2 % 2 == 0:
                p1 = p2 // 2
                p3 = 3 * p1
                if p1 >= 3:
                    ticks.append((int(p2h_f(p1)), p1))
                    h3 = int(p2h_f(p3))
                    if h2 < h3 < max_height:
                        ticks.append((h3, p3))
    else:
        ticks = candidates
    if not upsidedown:
        # Heights are measured from the baseline; convert them to row indices counted
        # from the graph's first line.
        ticks = [(height - 1 - h, p) for h, p in ticks]
    return [(h + y_offset, p) for h, p in ticks]


class ProcessMetricsScreen(BaseSelectableScreen):  # pylint: disable=too-many-instance-attributes
    """Full-screen view that plots the recent history of one selected GPU process.

    The header shows the process's current metrics and command line. Below it, four
    history graphs are drawn:
    - CPU percent (upper left)
    - host memory (lower left)
    - GPU memory (upper right)
    - GPU SM utilization (lower right)

    A daemon thread appends one sample to every graph each ``SNAPSHOT_INTERVAL``
    seconds while it is unpaused and the screen is enabled.
    """

    NAME: ClassVar[str] = 'process-metrics'
    SNAPSHOT_INTERVAL: ClassVar[float] = 0.5  # seconds between samples

    def __init__(self, *, win: curses.window, root: TUI) -> None:
        super().__init__(win, root)

        self.selection: Selection = Selection(self)
        # History graphs: created by ``enable()``, reset to ``None`` by ``disable()``.
        self.used_gpu_memory: HistoryGraph | None = None
        self.gpu_sm_utilization: HistoryGraph | None = None
        self.cpu_percent: HistoryGraph | None = None
        self.used_host_memory: HistoryGraph | None = None

        self.enabled: bool = False
        # Serializes access to the graphs and ``enabled`` between the UI and the daemon.
        self.snapshot_lock = threading.Lock()
        self._snapshot_daemon = threading.Thread(
            name='process-metrics-snapshot-daemon',
            target=self._snapshot_target,
            daemon=True,
        )
        # While set, the daemon samples in a loop; while cleared, it blocks in ``wait()``.
        self._daemon_running = threading.Event()

        self.x, self.y = root.x, root.y
        self.width, self.height = root.width, root.height
        # Split the space left by the frame borders into two columns and two graph rows.
        self.left_width: int = max(20, (self.width - 3) // 2)
        self.right_width: int = max(20, (self.width - 2) // 2)
        self.upper_height: int = max(5, (self.height - 5 - 3) // 2)
        self.lower_height: int = max(5, (self.height - 5 - 2) // 2)

    @property
    def visible(self) -> bool:
        """Whether the screen is shown.

        Assigning ``True`` resumes the snapshot daemon (starting it on first use) and
        takes an immediate snapshot. Assigning ``False`` clears ``focused`` but leaves
        the daemon running. A change of value requests a full redraw.
        """
        return self._visible

    @visible.setter
    def visible(self, value: bool) -> None:
        if self._visible != value:
            self.need_redraw = True
            self._visible = value
        if self.visible:
            self._start_snapshot_daemon()
            self.take_snapshots()
        else:
            self.focused = False

    def _start_snapshot_daemon(self) -> None:
        """Unpause the snapshot daemon, starting its thread if it has not started yet."""
        self._daemon_running.set()
        try:
            self._snapshot_daemon.start()
        except RuntimeError:
            # ``Thread.start()`` may be called only once; the thread was already started.
            pass

    def enable(self, state: bool = True) -> None:
        """Start recording metrics history for the selected process.

        Falls back to ``disable()`` when no process is selected or ``state`` is false.
        Otherwise it does the following:
        - (re)creates the four history graphs, discarding any previous history;
        - resumes the snapshot daemon;
        - takes a first sample;
        - lays out the graphs for the current terminal size.
        """
        if not self.selection.is_set() or not state:
            self.disable()
            return

        total_host_memory = host.virtual_memory().total
        total_host_memory_human = bytes2human(total_host_memory)
        total_gpu_memory = self.process.device.memory_total()
        total_gpu_memory_human = bytes2human(total_gpu_memory)

        # Label formatters for the current / maximum values shown on each graph.
        def format_cpu_percent(value: float) -> str:
            if value is NA:  # type: ignore[comparison-overlap]
                return f'CPU: {value}'
            return f'CPU: {value:.1f}%'

        def format_max_cpu_percent(value: float) -> str:
            if value is NA:  # type: ignore[comparison-overlap]
                return f'MAX CPU: {value}'
            return f'MAX CPU: {value:.1f}%'

        def format_host_memory(value: float) -> str:
            if value is NA:  # type: ignore[comparison-overlap]
                return f'HOST-MEM: {value}'
            return (
                f'HOST-MEM: {bytes2human(value)} '
                f'({round(100.0 * value / total_host_memory, 1):.1f}%)'
            )

        def format_max_host_memory(value: float) -> str:
            if value is NA:  # type: ignore[comparison-overlap]
                return f'MAX HOST-MEM: {value}'
            return (
                f'MAX HOST-MEM: {bytes2human(value)} '
                f'({round(100.0 * value / total_host_memory, 1):.1f}%) '
                f'/ {total_host_memory_human}'
            )

        def format_gpu_memory(value: float) -> str:
            if value is not NA and total_gpu_memory is not NA:  # type: ignore[comparison-overlap]
                return (
                    f'GPU-MEM: {bytes2human(value)} '
                    f'({round(100.0 * value / total_gpu_memory, 1):.1f}%)'
                )
            return f'GPU-MEM: {value}'

        def format_max_gpu_memory(value: float) -> str:
            if value is not NA and total_gpu_memory is not NA:  # type: ignore[comparison-overlap]
                return (
                    f'MAX GPU-MEM: {bytes2human(value)} '
                    f'({round(100.0 * value / total_gpu_memory, 1):.1f}%) '
                    f'/ {total_gpu_memory_human}'
                )
            return f'MAX GPU-MEM: {value}'

        def format_sm(value: float) -> str:
            if value is NA:  # type: ignore[comparison-overlap]
                return f'GPU-SM: {value}'
            return f'GPU-SM: {value:.1f}%'

        def format_max_sm(value: float) -> str:
            if value is NA:  # type: ignore[comparison-overlap]
                return f'MAX GPU-SM: {value}'
            return f'MAX GPU-SM: {value:.1f}%'

        with self.snapshot_lock:
            self.cpu_percent = BufferedHistoryGraph(
                interval=1.0,
                upperbound=1000.0,
                width=self.left_width,
                height=self.upper_height,
                baseline=0.0,
                upsidedown=False,
                dynamic_bound=True,
                min_bound=10.0,
                init_bound=100.0,
                format=format_cpu_percent,
                max_format=format_max_cpu_percent,
            )
            self.used_host_memory = BufferedHistoryGraph(
                interval=1.0,
                upperbound=total_host_memory,
                width=self.left_width,
                height=self.lower_height,
                baseline=0.0,
                upsidedown=True,
                dynamic_bound=True,
                format=format_host_memory,
                max_format=format_max_host_memory,
            )
            self.used_gpu_memory = BufferedHistoryGraph(
                interval=1.0,
                upperbound=total_gpu_memory or 1.0,  # type: ignore[arg-type]
                width=self.right_width,
                height=self.upper_height,
                baseline=0.0,
                upsidedown=False,
                dynamic_bound=True,
                format=format_gpu_memory,
                max_format=format_max_gpu_memory,
            )
            self.gpu_sm_utilization = BufferedHistoryGraph(
                interval=1.0,
                upperbound=100.0,
                width=self.right_width,
                height=self.lower_height,
                baseline=0.0,
                upsidedown=True,
                dynamic_bound=True,
                format=format_sm,
                max_format=format_max_sm,
            )
            # ``scale`` is read by ``get_yticks`` to convert values into tick percentages.
            self.cpu_percent.scale = 0.1  # type: ignore[attr-defined]
            self.used_host_memory.scale = 1.0  # type: ignore[attr-defined]
            self.used_gpu_memory.scale = 1.0  # type: ignore[attr-defined]
            self.gpu_sm_utilization.scale = 1.0  # type: ignore[attr-defined]

            self._start_snapshot_daemon()
            self.enabled = True

        self.take_snapshots()
        self.update_size()

    def disable(self) -> None:
        """Stop recording: pause the daemon and drop all history graphs."""
        with self.snapshot_lock:
            self._daemon_running.clear()
            self.enabled = False
            self.cpu_percent = None
            self.used_host_memory = None
            self.used_gpu_memory = None
            self.gpu_sm_utilization = None

    @property
    def process(self) -> GpuProcess:
        """The selected process; assigning one selects it and calls ``enable()``."""
        return self.selection.process  # type: ignore[return-value]

    @process.setter
    def process(self, value: GpuProcess) -> None:
        self.selection.process = value
        self.enable()

    @classmethod
    def set_snapshot_interval(cls, interval: float) -> None:
        """Set the sampling period to a third of ``interval`` seconds, capped at 1 second."""
        assert interval > 0.0
        interval = float(interval)

        cls.SNAPSHOT_INTERVAL = min(interval / 3.0, 1.0)

    def take_snapshots(self) -> None:
        """Append one sample of the selected process's metrics to every graph.

        Does nothing if no process is selected or the screen is disabled. The queries run
        inside ``GpuProcess.failsafe()``. That is presumably there to tolerate processes
        that vanish mid-query; confirm against its definition.
        """
        with self.snapshot_lock:
            if not self.selection.is_set() or not self.enabled:
                return

            with GpuProcess.failsafe():
                self.process.device.as_snapshot()
                self.process.update_gpu_status()
                snapshot = self.process.as_snapshot()

                assert self.cpu_percent is not None
                assert self.used_host_memory is not None
                assert self.used_gpu_memory is not None
                assert self.gpu_sm_utilization is not None
                self.cpu_percent.add(snapshot.cpu_percent)
                self.used_host_memory.add(snapshot.host_memory)
                self.used_gpu_memory.add(snapshot.gpu_memory)
                self.gpu_sm_utilization.add(snapshot.gpu_sm_utilization)

    def _snapshot_target(self) -> None:
        """Daemon loop: block while paused, take a sample, then sleep ``SNAPSHOT_INTERVAL``."""
        while True:
            self._daemon_running.wait()
            self.take_snapshots()
            time.sleep(self.SNAPSHOT_INTERVAL)

    def update_size(self, termsize: tuple[int, int] | None = None) -> tuple[int, int]:
        """Recompute the layout for the terminal size and resize the graphs if enabled.

        Returns:
            The ``(lines, columns)`` terminal size returned by the base class.
        """
        n_term_lines, n_term_cols = termsize = super().update_size(termsize=termsize)

        self.width = n_term_cols - self.x
        self.height = n_term_lines - self.y
        self.left_width = max(20, (self.width - 3) // 2)
        self.right_width = max(20, (self.width - 2) // 2)
        self.upper_height = max(5, (self.height - 8) // 2)
        self.lower_height = max(5, (self.height - 7) // 2)
        self.need_redraw = True

        with self.snapshot_lock:
            if self.enabled:
                assert self.cpu_percent is not None
                assert self.used_host_memory is not None
                assert self.used_gpu_memory is not None
                assert self.gpu_sm_utilization is not None
                self.cpu_percent.graph_size = (self.left_width, self.upper_height)
                self.used_host_memory.graph_size = (self.left_width, self.lower_height)
                self.used_gpu_memory.graph_size = (self.right_width, self.upper_height)
                self.gpu_sm_utilization.graph_size = (self.right_width, self.lower_height)

        return termsize

    def frame_lines(self) -> list[str]:
        """Return the static box-drawing frame, one string per screen row."""
        line = '│' + ' ' * self.left_width + '│' + ' ' * self.right_width + '│'
        return [
            '╒' + '═' * (self.width - 2) + '╕',
            '│ {} │'.format('Process:'.ljust(self.width - 4)),
            '│ {} │'.format('GPU'.ljust(self.width - 4)),
            '╞' + '═' * (self.width - 2) + '╡',
            '│' + ' ' * (self.width - 2) + '│',
            '╞' + '═' * self.left_width + '╤' + '═' * self.right_width + '╡',
            *([line] * self.upper_height),
            '├' + '─' * self.left_width + '┼' + '─' * self.right_width + '┤',
            *([line] * self.lower_height),
            '╘' + '═' * self.left_width + '╧' + '═' * self.right_width + '╛',
        ]

    def poke(self) -> None:
        """Resume sampling if visible while the daemon is paused, then delegate to the base."""
        if self.visible and not self._daemon_running.is_set():
            self._start_snapshot_daemon()
            self.take_snapshots()

        super().poke()

    def draw(self) -> None:  # pylint: disable=too-many-statements,too-many-locals,too-many-branches
        """Render the header and the four history graphs.

        The frame, the user@host context and the time-axis labels are drawn only when
        ``need_redraw`` is set. The metrics and graphs are drawn on every call while
        holding ``snapshot_lock``. The screen must be enabled, so that all graphs exist.
        """
        self.color_reset()

        assert self.used_gpu_memory is not None
        assert self.gpu_sm_utilization is not None
        assert self.cpu_percent is not None
        assert self.used_host_memory is not None

        if self.need_redraw:
            for y, line in enumerate(self.frame_lines(), start=self.y):
                self.addstr(y, self.x, line)

            context_width = wcslen(USER_CONTEXT)
            if not IS_WINDOWS or len(USER_CONTEXT) == context_width:
                # Do not support windows-curses with wide characters
                username_width = wcslen(USERNAME)
                hostname_width = wcslen(HOSTNAME)
                # Column relative to ``self.x``: right-align the context inside the frame.
                offset = self.width - context_width - 2
                self.addstr(self.y + 1, self.x + offset, USER_CONTEXT)
                self.color_at(self.y + 1, self.x + offset, width=context_width, attr='bold')
                self.color_at(
                    self.y + 1,
                    self.x + offset,
                    width=username_width,
                    fg=('yellow' if IS_SUPERUSER else 'magenta'),
                    attr='bold',
                )
                self.color_at(
                    self.y + 1,
                    self.x + offset + username_width + 1,
                    width=hostname_width,
                    fg='green',
                    attr='bold',
                )

            # Time-axis labels on the border between the graph rows, measured back from
            # each column's right edge. Labels that do not fit the left column are skipped
            # in both columns.
            for offset, string in (
                (19, '╴30s├'),
                (34, '╴60s├'),
                (65, '╴120s├'),
                (95, '╴180s├'),
                (125, '╴240s├'),
                (155, '╴300s├'),
            ):
                for x_offset, width in (
                    (self.x + 1 + self.left_width, self.left_width),
                    (self.x + 1 + self.left_width + 1 + self.right_width, self.right_width),
                ):
                    if offset > width:
                        break
                    self.addstr(self.y + self.upper_height + 6, x_offset - offset, string)
                    self.color_at(
                        self.y + self.upper_height + 6,
                        x_offset - offset + 1,
                        width=len(string) - 2,
                        attr='dim',
                    )

        with self.snapshot_lock:
            process = self.process.snapshot
            columns = OrderedDict(
                [
                    (' GPU', self.process.device.display_index.rjust(4)),
                    ('PID  ', f'{str(process.pid).rjust(3)} {process.type}'),
                    (
                        'USER',
                        WideString(
                            cut_string(
                                WideString(process.username).rjust(4),
                                maxlen=32,
                                padstr='+',
                            ),
                        ),
                    ),
                    (' GPU-MEM', process.gpu_memory_human.rjust(8)),
                    (' %SM', str(process.gpu_sm_utilization).rjust(4)),
                    ('%GMBW', str(process.gpu_memory_utilization).rjust(5)),
                    ('%ENC', str(process.gpu_encoder_utilization).rjust(4)),
                    ('%DEC', str(process.gpu_decoder_utilization).rjust(4)),
                    ('  %CPU', process.cpu_percent_string.rjust(6)),
                    (' %MEM', process.memory_percent_string.rjust(5)),
                    (' TIME', (' ' + process.running_time_human).rjust(5)),
                ],
            )

            # Lay out as many header columns as fit; stop at the first one that does not.
            x = self.x + 1
            header = ''
            fields = WideString()
            no_break = True
            for i, (col, value) in enumerate(columns.items()):
                width = len(value)
                if x + width < self.width - 2:
                    if i == 0:
                        header += col.rjust(width)
                        fields += value
                    else:
                        header += ' ' + col.rjust(width)
                        fields += ' ' + value
                    x = self.x + 1 + len(fields)
                else:
                    no_break = False
                    break

            self.addstr(self.y + 2, self.x + 1, header.ljust(self.width - 2))
            self.addstr(self.y + 4, self.x + 1, str(fields.ljust(self.width - 2)))
            self.color_at(
                self.y + 4,
                self.x + 1,
                width=4,
                fg=self.process.device.snapshot.display_color,
            )

            # The command line is only shown when every metric column fitted.
            if no_break:
                x = self.x + 1 + len(fields) + 2
                if x + 4 < self.width - 2:
                    self.addstr(
                        self.y + 2,
                        x,
                        cut_string('COMMAND', self.width - x - 2, padstr='..').ljust(
                            self.width - x - 2,
                        ),
                    )
                    if process.is_zombie or process.no_permissions:
                        self.color(fg='yellow')
                    elif process.is_gone:
                        self.color(fg='red')
                    self.addstr(
                        self.y + 4,
                        x,
                        cut_string(
                            WideString(process.command).ljust(self.width - x - 2),
                            self.width - x - 2,
                            padstr='..',
                        ),
                    )

            self.color(fg='cyan')
            for y, line in enumerate(self.cpu_percent.graph, start=self.y + 6):
                self.addstr(y, self.x + 1, line)

            self.color(fg='magenta')
            for y, line in enumerate(
                self.used_host_memory.graph,
                start=self.y + self.upper_height + 7,
            ):
                self.addstr(y, self.x + 1, line)

            if self.TERM_256COLOR:
                # Per-row colour gradient: the colour input grows with the row's distance
                # from the graph's baseline, scaled by how much of ``max_bound`` is shown.
                scale = (self.used_gpu_memory.bound / self.used_gpu_memory.max_bound) / (
                    self.upper_height - 1
                )
                for i, (y, line) in enumerate(
                    enumerate(self.used_gpu_memory.graph, start=self.y + 6),
                ):
                    self.addstr(
                        y,
                        self.x + self.left_width + 2,
                        line,
                        self.get_fg_bg_attr(fg=(self.upper_height - i - 1) * scale),
                    )

                scale = (self.gpu_sm_utilization.bound / self.gpu_sm_utilization.max_bound) / (
                    self.lower_height - 1
                )
                for i, (y, line) in enumerate(
                    enumerate(self.gpu_sm_utilization.graph, start=self.y + self.upper_height + 7),
                ):
                    self.addstr(
                        y,
                        self.x + self.left_width + 2,
                        line,
                        self.get_fg_bg_attr(fg=i * scale),
                    )
            else:
                self.color(fg=self.process.device.snapshot.memory_display_color)
                for y, line in enumerate(self.used_gpu_memory.graph, start=self.y + 6):
                    self.addstr(y, self.x + self.left_width + 2, line)

                self.color(fg=self.process.device.snapshot.gpu_display_color)
                for y, line in enumerate(
                    self.gpu_sm_utilization.graph,
                    start=self.y + self.upper_height + 7,
                ):
                    self.addstr(y, self.x + self.left_width + 2, line)

            # Current and maximum value labels, overlaid on the graphs.
            self.color_reset()
            self.addstr(self.y + 6, self.x + 1, f' {self.cpu_percent.max_value_string()} ')
            self.addstr(self.y + 7, self.x + 5, f' {self.cpu_percent} ')
            self.addstr(
                self.y + self.upper_height + self.lower_height + 5,
                self.x + 5,
                f' {self.used_host_memory} ',
            )
            self.addstr(
                self.y + self.upper_height + self.lower_height + 6,
                self.x + 1,
                ' {} '.format(
                    cut_string(
                        self.used_host_memory.max_value_string(),
                        maxlen=self.left_width - 2,
                        padstr='..',
                    ),
                ),
            )
            self.addstr(
                self.y + 6,
                self.x + self.left_width + 2,
                ' {} '.format(
                    cut_string(
                        self.used_gpu_memory.max_value_string(),
                        maxlen=self.right_width - 2,
                        padstr='..',
                    ),
                ),
            )
            self.addstr(self.y + 7, self.x + self.left_width + 6, f' {self.used_gpu_memory} ')
            self.addstr(
                self.y + self.upper_height + self.lower_height + 5,
                self.x + self.left_width + 6,
                f' {self.gpu_sm_utilization} ',
            )
            self.addstr(
                self.y + self.upper_height + self.lower_height + 6,
                self.x + self.left_width + 2,
                f' {self.gpu_sm_utilization.max_value_string()} ',
            )

            # Restore the vertical borders before drawing y-axis ticks onto them.
            for y in range(self.y + 6, self.y + 6 + self.upper_height):
                self.addstr(y, self.x, '│')
                self.addstr(y, self.x + self.left_width + 1, '│')
            for y in range(
                self.y + self.upper_height + 7,
                self.y + self.upper_height + self.lower_height + 7,
            ):
                self.addstr(y, self.x, '│')
                self.addstr(y, self.x + self.left_width + 1, '│')

            self.color(attr='dim')
            for y, p in itertools.chain(
                get_yticks(self.cpu_percent, self.y + 6),
                get_yticks(self.used_host_memory, self.y + self.upper_height + 7),
            ):
                self.addstr(y, self.x, f'├╴{p}% ')
                self.color_at(y, self.x, width=2, attr=0)
            x = self.x + self.left_width + 1
            for y, p in itertools.chain(
                get_yticks(self.used_gpu_memory, self.y + 6),
                get_yticks(self.gpu_sm_utilization, self.y + self.upper_height + 7),
            ):
                self.addstr(y, x, f'├╴{p}% ')
                self.color_at(y, x, width=2, attr=0)

    def destroy(self) -> None:
        """Tear down the screen via the base class, then pause the snapshot daemon."""
        super().destroy()
        self._daemon_running.clear()

    def press(self, key: int) -> bool:
        """Dispatch ``key`` through the root using the ``'process-metrics'`` keymap.

        Returns:
            Whatever ``root.press`` returns; presumably whether the key was handled.
        """
        self.root.keymaps.use_keymap('process-metrics')
        return self.root.press(key)
