#
# Copyright 2017 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301 USA
#
# Refer to the README and COPYING files for full details of the license
#
from __future__ import absolute_import

from contextlib import contextmanager
import logging
import threading

import libvirt

from vdsm.common import response
from vdsm.virt.vmdevices.storage import Drive, DISK_TYPE, BLOCK_THRESHOLD
from vdsm.virt.vmdevices import hwclass
from vdsm.virt import drivemonitor
from vdsm.virt import vm
from vdsm.virt import vmstatus

from testlib import expandPermutations, permutations
from testlib import make_config
from testlib import maybefail
from testlib import VdsmTestCase
import vmfakelib as fake

from monkeypatch import MonkeyPatchScope


MB = 1 << 20
GB = 1 << 30


# Drive extension tunables patched onto the Drive class by make_env().
CHUNK_SIZE = GB
CHUNK_PCT = 50


@contextmanager
def make_env(events_enabled):
    """
    Set up a fake VM with two qcow2 ('cow') drives on block storage.

    Drive's VOLWM_CHUNK_SIZE / VOLWM_FREE_PCT and drivemonitor's config are
    monkeypatched for the lifetime of the context; the config turns
    'enable_block_threshold_event' on or off according to events_enabled.

    Yields a (FakeVM, FakeDomain, [Drive, Drive]) tuple.
    """
    log = logging.getLogger('test')

    event_flag = 'true' if events_enabled else 'false'
    cfg = make_config([('irs', 'enable_block_threshold_event', event_flag)])

    # The Drive class uses these two tunables as class constants, so they
    # are patched on the class itself, before any Drive is created.
    patches = [
        (Drive, 'VOLWM_CHUNK_SIZE', CHUNK_SIZE),
        (Drive, 'VOLWM_FREE_PCT', CHUNK_PCT),
        (drivemonitor, 'config', cfg),
    ]
    with MonkeyPatchScope(patches):
        # Storage does not validate the UUIDs, so short phony names are
        # used for brevity.
        drives = [
            Drive(log, **drive_config(
                format='cow',
                diskType=DISK_TYPE.BLOCK,
                index=idx,
                volumeID='volume_%d' % idx,
                poolID='pool_0',
                imageID='image_%d' % idx,
                domainID='domain_0',
            ))
            for idx in range(2)
        ]
        # TODO: add raw/block drive and qcow2/file drive.
        # check we don't try to monitor or extend those drives.

        dom = FakeDomain()
        irs = FakeIRS()

        # Seed storage's idea of each volume size from the fake domain.
        for drv in drives:
            capacity = dom.blockInfo(drv.path, 0)[0]
            volume_key = (drv.domainID, drv.poolID,
                          drv.imageID, drv.volumeID)
            irs.volume_sizes[volume_key] = capacity
            drv.apparentsize = capacity

        cif = FakeClientIF()
        cif.irs = irs
        yield FakeVM(cif, dom, drives), dom, drives


def allocation_threshold_for_resize_mb(block_info, drive):
    """
    Return the allocation at which the drive reaches its watermark.

    Computed as the drive's physical size minus drive.watermarkLimit.
    Despite the "_mb" suffix, the result is in the same unit as
    block_info (bytes, in this module's fake block info).
    """
    physical_size = block_info['physical']
    return physical_size - drive.watermarkLimit


class DiskExtensionTestBase(VdsmTestCase):
    """Assertions shared by the drive extension test cases."""

    # helpers

    def check_extension(self, drive_info, drive_obj, extension_req):
        """
        Check that an extension request targets the expected volume.

        drive_info: block info dict of the drive ('capacity', 'allocation',
            'physical'), as stored in FakeDomain.block_info.
        drive_obj: the Drive expected to be extended.
        extension_req: (poolID, volInfo, newSize, func) tuple, as recorded
            by FakeIRS.sendExtendMsg.
        """
        poolID, volInfo, newSize, func = extension_req

        # We do only minimal validation here. Specific test(s) should
        # check that the callable actually finishes the extension process.
        self.assertTrue(callable(func))

        self.assertEqual(drive_obj.poolID, poolID)

        expected_size = drive_obj.getNextVolumeSize(
            drive_info['physical'], drive_info['capacity'])
        self.assertEqual(expected_size, newSize)

        self.assertEqual(expected_size, volInfo['newSize'])
        self.assertEqual(drive_obj.name, volInfo['name'])

        # While a replication is in progress the request is expected to
        # carry the replica's volume IDs; otherwise the drive's own IDs.
        if drive_obj.isDiskReplicationInProgress():
            self.assertEqual(drive_obj.diskReplicate['domainID'],
                             volInfo['domainID'])
            self.assertEqual(drive_obj.diskReplicate['imageID'],
                             volInfo['imageID'])
            self.assertEqual(drive_obj.diskReplicate['poolID'],
                             volInfo['poolID'])
            self.assertEqual(drive_obj.diskReplicate['volumeID'],
                             volInfo['volumeID'])
        else:
            self.assertEqual(drive_obj.domainID, volInfo['domainID'])
            self.assertEqual(drive_obj.imageID, volInfo['imageID'])
            self.assertEqual(drive_obj.poolID, volInfo['poolID'])
            self.assertEqual(drive_obj.volumeID, volInfo['volumeID'])


class TestDiskExtensionWithPolling(DiskExtensionTestBase):
    """
    Drive extension driven only by periodic monitor_drives() calls, with
    block threshold events disabled.
    """

    def test_no_extension_allocation_below_watermark(self):

        with make_env(events_enabled=False) as (testvm, dom, drives):
            vda = dom.block_info['/virtio/0']
            vda['allocation'] = 0 * MB
            vdb = dom.block_info['/virtio/1']
            vdb['allocation'] = allocation_threshold_for_resize_mb(
                vdb, drives[1]) - 1 * MB

            extended = testvm.monitor_drives()

        self.assertEqual(extended, False)
        # Not extending must also mean no request was sent to storage.
        self.assertEqual(testvm.cif.irs.extensions, [])

    def test_no_extension_maximum_size_reached(self):

        with make_env(events_enabled=False) as (testvm, dom, drives):
            vda = dom.block_info['/virtio/0']
            vda['allocation'] = 0 * MB
            vdb = dom.block_info['/virtio/1']
            max_size = drives[1].getMaxVolumeSize(vdb['capacity'])
            vdb['allocation'] = max_size
            vdb['physical'] = max_size
            extended = testvm.monitor_drives()

        self.assertEqual(extended, False)
        # Not extending must also mean no request was sent to storage.
        self.assertEqual(testvm.cif.irs.extensions, [])

    def test_extend_drive_allocation_crosses_watermark_limit(self):

        with make_env(events_enabled=False) as (testvm, dom, drives):
            vda = dom.block_info['/virtio/0']
            vda['allocation'] = 0 * MB
            vdb = dom.block_info['/virtio/1']
            vdb['allocation'] = allocation_threshold_for_resize_mb(
                vdb, drives[1]) + 1 * MB

            extended = testvm.monitor_drives()

        self.assertEqual(extended, True)
        self.assertEqual(len(testvm.cif.irs.extensions), 1)
        self.check_extension(vdb, drives[1], testvm.cif.irs.extensions[0])

    def test_extend_drive_allocation_equals_next_size(self):

        with make_env(events_enabled=False) as (testvm, dom, drives):
            vda = dom.block_info['/virtio/0']
            vda['allocation'] = drives[0].getNextVolumeSize(
                vda['physical'], vda['capacity'])
            vdb = dom.block_info['/virtio/1']
            vdb['allocation'] = 0 * MB
            extended = testvm.monitor_drives()

        self.assertEqual(extended, True)
        self.assertEqual(len(testvm.cif.irs.extensions), 1)
        self.check_extension(vda, drives[0], testvm.cif.irs.extensions[0])

    def test_stop_extension_loop_on_improbable_request(self):

        with make_env(events_enabled=False) as (testvm, dom, drives):
            # Allocation beyond the next volume size cannot be explained
            # by normal guest writes.
            vda = dom.block_info['/virtio/0']
            vda['allocation'] = (
                drives[0].getNextVolumeSize(
                    vda['physical'], vda['capacity']) + 1 * MB)
            vdb = dom.block_info['/virtio/1']
            vdb['allocation'] = 0 * MB
            extended = testvm.monitor_drives()

        self.assertEqual(extended, False)
        self.assertEqual(dom.info()[0], libvirt.VIR_DOMAIN_PAUSED)
        # The VM is paused instead of extended: nothing sent to storage.
        self.assertEqual(testvm.cif.irs.extensions, [])

    # TODO: add the same test for disk replicas.
    def test_vm_resumed_after_drive_extended(self):

        with make_env(events_enabled=False) as (testvm, dom, drives):
            testvm.pause()

            vda = dom.block_info['/virtio/0']
            vda['allocation'] = 0 * MB
            vdb = dom.block_info['/virtio/1']  # shortcut
            vdb['allocation'] = allocation_threshold_for_resize_mb(
                vdb, drives[1]) + 1 * MB

            extended = testvm.monitor_drives()
            self.assertEqual(extended, True)
            self.assertEqual(len(testvm.cif.irs.extensions), 1)

            # Simulate completed extend operation, invoking callback

            simulate_extend_callback(testvm.cif.irs, extension_id=0)

        self.assertEqual(testvm.lastStatus, vmstatus.UP)
        self.assertEqual(dom.info()[0], libvirt.VIR_DOMAIN_RUNNING)

    # TODO: add test with storage failures in the extension flow


@expandPermutations
class TestDiskExtensionWithEvents(DiskExtensionTestBase):
    """
    Drive extension driven by block threshold events, with
    'enable_block_threshold_event' turned on in the config.
    """

    # TODO: missing tests:
    # - call extend_if_needed when drive.threshold_state is EXCEEDED
    #   -> extend
    # FIXME: already covered by existing cases?

    def test_extend_using_events(self):
        with make_env(events_enabled=True) as (testvm, dom, drives):

            # first run: does nothing but set the block thresholds
            testvm.monitor_drives()

            # Simulate writing to drive vdb
            vdb = dom.block_info['/virtio/1']

            alloc = allocation_threshold_for_resize_mb(
                vdb, drives[1]) + 1 * MB

            vdb['allocation'] = alloc

            # Simulating block threshold event
            testvm.drive_monitor.on_block_threshold(
                'vdb', '/virtio/1', alloc, 1 * MB)

            drv = drives[1]
            self.assertEqual(drv.threshold_state, BLOCK_THRESHOLD.EXCEEDED)

            # Simulating periodic check
            extended = testvm.monitor_drives()
            self.assertEqual(extended, True)
            self.assertEqual(len(testvm.cif.irs.extensions), 1)
            self.check_extension(vdb, drives[1], testvm.cif.irs.extensions[0])
            # Still EXCEEDED while the extension is in flight.
            self.assertEqual(drv.threshold_state, BLOCK_THRESHOLD.EXCEEDED)

            # Simulate completed extend operation, invoking callback

            simulate_extend_callback(testvm.cif.irs, extension_id=0)

            # After the extension completes a new threshold is set.
            drv = drives[1]
            self.assertEqual(drv.threshold_state, BLOCK_THRESHOLD.SET)

    # Expected threshold for vda (physical 2 GB): replicating drives are
    # expected to get a lower threshold (1024 MB vs 1536 MB).
    @permutations([
        # replicating, threshold
        (False, 1536 * MB),
        (True, 1024 * MB),
    ])
    def test_set_new_threshold_when_state_unset(self, replicating, threshold):
        with make_env(events_enabled=True) as (testvm, dom, drives):

            dom.block_info['/virtio/0'] = {
                'capacity': 4 * GB,
                'allocation': 0 * GB,
                'physical': 2 * GB,
            }

            vda = drives[0]
            if replicating:
                vda.diskReplicate = {
                    'diskType': DISK_TYPE.BLOCK,
                    'format': 'cow',
                }

            self.assertEqual(vda.threshold_state, BLOCK_THRESHOLD.UNSET)
            # first run: does nothing but set the block thresholds
            testvm.monitor_drives()

            self.assertEqual(vda.threshold_state, BLOCK_THRESHOLD.SET)
            self.assertEqual(dom.thresholds[vda.name], threshold)

    def test_set_new_threshold_when_state_unset_but_fails(self):

        with make_env(events_enabled=True) as (testvm, dom, drives):
            for drive in drives:
                self.assertEqual(drive.threshold_state, BLOCK_THRESHOLD.UNSET)

            # Simulate setBlockThreshold failure
            testvm._dom.errors["setBlockThreshold"] = fake.Error(
                libvirt.VIR_ERR_OPERATION_FAILED, "fake error")

            # first run: does nothing but set the block thresholds
            testvm.monitor_drives()

            # The failed call must leave every drive UNSET, so setting the
            # threshold can be retried.
            for drive in drives:
                self.assertEqual(drive.threshold_state, BLOCK_THRESHOLD.UNSET)

    def test_set_new_threshold_when_state_set(self):
        # Vm.monitor_drives must not pick up drives with
        # threshold_state == SET, so we call
        # Vm.extend_drive_if_needed explicitly
        with make_env(events_enabled=True) as (testvm, dom, drives):
            drives[0].threshold_state = BLOCK_THRESHOLD.SET

            extended = testvm.extend_drive_if_needed(drives[0])

            self.assertFalse(extended)

    def test_event_received_before_write_completes(self):
        # QEMU submits an event when a write is attempted, so it
        # is possible that at the time we receive the event the
        # write was not completed yet, or failed, and the
        # volume allocation is still below the threshold.
        # We will not extend the drive, but keep it marked for
        # extension.
        with make_env(events_enabled=True) as (testvm, dom, drives):

            # NOTE: write not yet completed, so the allocation value
            # for the drive must be below the value reported in
            # the event (vda's fake allocation is left at 0 here).
            vda = dom.block_info['/virtio/0']

            alloc = allocation_threshold_for_resize_mb(
                vda, drives[0]) + 1 * MB

            # Simulating block threshold event
            testvm.drive_monitor.on_block_threshold(
                'vda', '/virtio/0', alloc, 1 * MB)

            testvm.monitor_drives()

            # The threshold state is correctly kept as exceeded, so extension
            # will be tried again next cycle.
            self.assertEqual(drives[0].threshold_state,
                             BLOCK_THRESHOLD.EXCEEDED)

    def test_block_threshold_set_failure_after_drive_extended(self):

        with make_env(events_enabled=True) as (testvm, dom, drives):

            # first run: does nothing but set the block thresholds
            testvm.monitor_drives()

            # Simulate write on drive vdb
            vdb = dom.block_info['/virtio/1']

            # The BLOCK_THRESHOLD event contains the highest allocated
            # block...
            alloc = allocation_threshold_for_resize_mb(
                vdb, drives[1]) + 1 * MB

            # ... but we repeat the check in monitor_drives(),
            # so we need to set both locations to the correct value.
            vdb['allocation'] = alloc

            # Simulating block threshold event
            testvm.drive_monitor.on_block_threshold(
                'vdb', '/virtio/1', alloc, 1 * MB)

            # Simulating periodic check
            testvm.monitor_drives()
            self.assertEqual(len(testvm.cif.irs.extensions), 1)

            # Simulate completed extend operation, invoking callback

            # Simulate setBlockThreshold failure, so setting the new
            # threshold after the extension fails
            testvm._dom.errors["setBlockThreshold"] = fake.Error(
                libvirt.VIR_ERR_OPERATION_FAILED, "fake error")

            simulate_extend_callback(testvm.cif.irs, extension_id=0)

            drv = drives[1]
            self.assertEqual(drv.threshold_state, BLOCK_THRESHOLD.UNSET)


class FakeVM(vm.Vm):
    """
    vm.Vm subclass carrying only the state the drive monitoring tests use.

    __init__ deliberately does not call vm.Vm.__init__; it sets by hand the
    attributes needed by monitor_drives() and by pause()/cont().
    """

    log = logging.getLogger('test')

    def __init__(self, cif, dom, disks):
        """
        cif: fake client interface; its .irs is the fake storage API.
        dom: FakeDomain standing in for the libvirt domain.
        disks: Drive objects, registered as the VM's DISK devices.
        """
        self.id = 'drive_monitor_vm'
        self.cif = cif
        self.drive_monitor = drivemonitor.DriveMonitor(self, self.log)
        self._dom = dom
        # Only disk devices are provided; presumably this is all the drive
        # monitoring code reads from _devices here — confirm against vm.Vm.
        self._devices = {hwclass.DISK: disks}

        # needed for pause()/cont()

        self._lastStatus = vmstatus.UP
        self._guestCpuRunning = True
        self._custom = {}
        self._confLock = threading.Lock()
        self.conf = {}
        self._guestCpuLock = threading.Lock()
        self._resume_behavior = 'auto_resume'
        self._pause_time = None

    # to reduce the amount of faking needed, we fake those methods
    # which are not relevant to the monitor_drives() flow

    def send_status_event(self, **kwargs):
        # No-op: status events are not inspected by these tests.
        pass

    def isMigrating(self):
        # The fake VM is never migrating.
        return False

    def _update_metadata(self):
        # No-op: metadata is not inspected by these tests.
        pass


class FakeDomain(object):
    """
    Fake libvirt domain exposing only what drive monitoring touches.

    block_info maps a drive path to a dict with 'capacity', 'allocation'
    and 'physical' sizes; tests mutate those dicts to simulate writes.
    thresholds records the value passed to setBlockThreshold per device.
    errors is where tests register failures (keyed by method name, e.g.
    "setBlockThreshold") for the @maybefail-decorated methods.
    """

    def __init__(self):
        self._state = (libvirt.VIR_DOMAIN_RUNNING, )
        # capacity: arbitrary value > 0
        # physical: arbitrary value > 0 and <= capacity
        vda_info = dict(capacity=4 * GB, allocation=0 * GB, physical=2 * GB)
        vdb_info = dict(capacity=2 * GB, allocation=0 * GB, physical=1 * GB)
        self.block_info = {'/virtio/0': vda_info, '/virtio/1': vdb_info}
        self.errors = {}
        self.thresholds = {}

    def blockInfo(self, path, flags):
        """Return (capacity, allocation, physical) for path.

        flags is ignored.
        """
        # TODO: support access by name
        info = self.block_info[path]
        return tuple(info[k] for k in ('capacity', 'allocation', 'physical'))

    # The following is needed in the 'pause' flow triggered
    # by the ImprobableResizeRequestError

    def XMLDesc(self, flags):
        """Return a minimal domain XML; flags is ignored."""
        return u'<domain/>'

    def suspend(self):
        """Mark the domain as paused."""
        self._state = (libvirt.VIR_DOMAIN_PAUSED, )

    def resume(self):
        """Mark the domain as running."""
        self._state = (libvirt.VIR_DOMAIN_RUNNING, )

    def info(self):
        """Return a 1-tuple holding only the domain state."""
        return self._state

    @maybefail
    def setBlockThreshold(self, dev, threshold):
        """Record the threshold requested for dev."""
        self.thresholds[dev] = threshold


class FakeClientIF(fake.ClientIF):
    """Client interface fake that silently drops every notification."""

    def notify(self, event_id, params=None):
        """Ignore the event; these tests never inspect notifications."""
        return None


class FakeIRS(object):
    """Fake storage API recording extension and refresh requests."""

    def __init__(self):
        # One (poolID, volInfo, newSize, func) tuple per sendExtendMsg call.
        self.extensions = []
        # One (domainID, poolID, imageID, volumeID) tuple per refresh.
        self.refreshes = []
        # Volume size keyed by (domainID, poolID, imageID, volumeID).
        self.volume_sizes = {}

    def sendExtendMsg(self, poolID, volInfo, newSize, func):
        request = (poolID, volInfo, newSize, func)
        self.extensions.append(request)

    def refreshVolume(self, domainID, poolID, imageID, volumeID):
        self.refreshes.append((domainID, poolID, imageID, volumeID))

    def getVolumeSize(self, domainID, poolID, imageID, volumeID):
        # On block storage "truesize" and "apparentsize" are always the
        # same; both exist only for compatibility with file volumes.
        size = self.volume_sizes[(domainID, poolID, imageID, volumeID)]
        return response.success(apparentsize=size, truesize=size)


# TODO: factor out this function and its counterpart in vmstorage_test.py
def drive_config(**kw):
    '''
    Return a drive configuration dict: built-in defaults overridden by **kw.

    The 'path' key is always derived from the final 'iface' and 'index'
    values as '/<iface>/<index>', replacing any 'path' passed in kw.
    '''
    defaults = dict(
        device='disk',
        format='raw',
        iface='virtio',
        index='0',
        propagateErrors='off',
        readonly='False',
        shared='none',
        type='disk',
    )
    conf = dict(defaults, **kw)
    conf['path'] = '/{iface}/{index}'.format(**conf)
    return conf


def simulate_extend_callback(irs, extension_id):
    """
    Complete an extension request previously recorded by FakeIRS.

    Sets the fake volume size to the requested new size (as a refresh on
    the storage side would), invokes the request's callback, and checks
    that the callback refreshed the extended volume.

    irs: FakeIRS holding the recorded requests.
    extension_id: index into irs.extensions of the request to complete.

    Raises AssertionError if the callback did not refresh the volume.
    """
    _pool_id, vol_info, new_size, callback = irs.extensions[extension_id]
    key = (vol_info['domainID'], vol_info['poolID'],
           vol_info['imageID'], vol_info['volumeID'])
    # Simulate refresh, updating local volume size
    irs.volume_sizes[key] = new_size

    # Only refreshes issued by this callback count, so completing a second
    # extension (or one with no refresh at all) is checked correctly.
    refreshes_before = len(irs.refreshes)
    callback(vol_info)

    # Calling refreshVolume is critical in this flow.
    # Check this indeed happened. key is wrapped in a 1-tuple: a bare
    # 4-tuple after % would be taken as four format arguments and raise
    # TypeError instead of the intended AssertionError.
    if key not in irs.refreshes[refreshes_before:]:
        raise AssertionError('Volume %s not refreshed' % (key,))
