#!/usr/bin/python3
#
# Copyright (C) Citrix Systems Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation; version 2.1 only.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
#
# Pause/unpause tapdisk on the local host

import os
import sys
import XenAPIPlugin
sys.path.append("/opt/xensource/sm/")
import blktap2, util
from lock import Lock
import xs_errors
import XenAPI
import lvmcache

from constants import NS_PREFIX_LVM, VG_PREFIX
from cowutil import getCowUtil
from lvmcowutil import LV_PREFIX, LV_PREFIX_TO_VDI_TYPE, LvmCowUtil
from vditype import VdiType

try:
    from linstorcowutil import LinstorCowUtil
    from linstorvolumemanager import get_controller_uri, LinstorVolumeManager
    LINSTOR_AVAILABLE = True
except ImportError:
    LINSTOR_AVAILABLE = False

# Directory prefix of the per-VDI backend device node: Tapdisk stats
# <prefix>/<sr_uuid>/<vdi_uuid> to obtain the device major/minor numbers.
TAPDEV_BACKPATH_PFX = "/dev/sm/backend"
# Directory prefix of the per-VDI physical-path symlink: Tapdisk reads
# <prefix>/<sr_uuid>/<vdi_uuid> with os.readlink() when refreshing paths.
TAPDEV_PHYPATH_PFX = "/dev/sm/phy"

def locking(excType, override=True):
    """Build a decorator that runs a method while holding ``self.lock``.

    The decorated method's instance must provide ``failfast`` (bool) and
    ``lock`` (an object with ``acquire``, ``acquireNoblock`` and ``release``).

    excType  -- XenError type name raised when the lock cannot be taken in
                fail-fast mode, and used to re-wrap storage errors.
    override -- when true, ``util.SMException`` / ``XenAPI.Failure`` raised by
                the method are converted into ``xs_errors.XenError(excType)``;
                when false they propagate unchanged.

    Every exception raised by the method is logged before it propagates, and
    the lock is always released once it has been acquired.
    """
    # Local import so this block does not rely on the file's import section.
    import functools

    def locking2(op):
        # functools.wraps keeps the wrapped method's __name__/__doc__ so that
        # introspection and tracebacks show the real method, not "wrapper".
        @functools.wraps(op)
        def wrapper(self, *args, **kwargs):
            if self.failfast:
                # Fail-fast callers must not wait behind another holder.
                if not self.lock.acquireNoblock():
                    raise xs_errors.XenError(excType,
                            opterr='VDI already locked')
            else:
                self.lock.acquire()
            try:
                try:
                    return op(self, *args, **kwargs)
                except (util.SMException, XenAPI.Failure) as e:
                    util.logException("TAP-PAUSE:%s" % op)
                    msg = str(e)
                    if isinstance(e, util.CommandException):
                        # Command failures carry structured details.
                        msg = "Command %s failed (%s): %s" % \
                                (e.cmd, e.code, e.reason)
                    if override:
                        raise xs_errors.XenError(excType, opterr=msg)
                    raise
                except BaseException:
                    # Log anything unexpected, then let it propagate as-is.
                    util.logException("TAP-PAUSE:%s" % op)
                    raise
            finally:
                self.lock.release()
        return wrapper
    return locking2

def _getDevMajor_minor(dev):
    st = os.stat(dev)
    return [os.major(st.st_rdev),os.minor(st.st_rdev)]

def _mkphylink(sr_uuid, vdi_uuid, path):
    """Point the physical-path symlink of a VDI at *path*.

    Runs ``ln -sf`` through ``util.pread2`` so that an existing link at
    /dev/sm/phy/<sr_uuid>/<vdi_uuid> is replaced.

    Returns *path* unchanged.
    """
    link_location = "/dev/sm/phy/%s/%s" % (sr_uuid, vdi_uuid)
    util.pread2(['ln', '-sf', path, link_location])
    return path

def tapPause(session, args):
    """Plugin entry point: pause the tapdisk described by *args*.

    Returns whatever ``Tapdisk.Pause`` returns (the string "True" or "False").
    """
    return Tapdisk(session, args).Pause()

def tapUnpause(session, args):
    """Plugin entry point: unpause the tapdisk described by *args*.

    Returns whatever ``Tapdisk.Unpause`` returns (the string "True" or
    "False").
    """
    return Tapdisk(session, args).Unpause()

def tapRefresh(session, args):
    """Plugin entry point: pause then immediately unpause a tapdisk.

    The unpause step is only attempted when the pause step reported "True";
    otherwise "False" is returned straight away.
    """
    handler = Tapdisk(session, args)
    pause_result = handler.Pause()
    if pause_result == "True":
        return handler.Unpause()
    return str(False)


class Tapdisk:
    """Pause/unpause operations on the tapdisk serving one VDI of one SR.

    An instance is built from the plugin ``args`` dictionary. The public
    operations run under the per-VDI lock via the ``locking`` decorator and
    report their outcome as the string "True" or "False".
    """

    def __init__(self, session, args):
        """Record the request parameters and derive the device paths.

        session -- XenAPI session used for the SR/PBD/VDI lookups.
        args    -- plugin arguments; "sr_uuid" and "vdi_uuid" are required,
                   "failfast", "secondary" and "activate_parents" are
                   optional.
        """
        self.sr_uuid = args["sr_uuid"]
        self.vdi_uuid = args["vdi_uuid"]
        # failfast selects non-blocking acquisition of the VDI lock.
        # NOTE(security): the raw argument string is handed to eval();
        # presumably it is the str() of a Python bool -- confirm that this
        # value can never come from an untrusted caller.
        self.failfast = eval(args['failfast']) if 'failfast' in args else False
        self.session = session
        self.path = os.path.join(
            TAPDEV_BACKPATH_PFX, self.sr_uuid, self.vdi_uuid)
        self.phypath = os.path.join(
            TAPDEV_PHYPATH_PFX, self.sr_uuid, self.vdi_uuid)
        self.lock = Lock("vdi", self.vdi_uuid)
        # Filled in by _pathRefresh() when the physical path had to be
        # re-resolved; left as None otherwise.
        self.realpath = None
        self.vdi_type = None
        self.secondary = args.get("secondary")
        self.activate_parents = args.get("activate_parents") == "true"

    def _pathRefresh(self):
        """Re-resolve the physical path behind the phy symlink if needed.

        Handles two layouts, chosen by the prefix of the current link target:
        LVM volumes whose device node no longer exists, and LINSTOR/DRBD
        volumes. Sets ``self.realpath`` / ``self.vdi_type`` when a path is
        resolved. Does nothing when the symlink cannot be read.
        """
        try:
            link_target = os.readlink(self.phypath)
        except OSError:
            util.SMlog("Phypath %s does not exist" % self.phypath)
            return
        util.SMlog("Realpath: %s" % link_target)

        # LVM rename check: the link points into the SR volume group but the
        # device node is gone.
        if link_target.startswith("/dev/VG_XenStorage-") and \
                not os.path.exists(link_target):
            self._refreshLvmPath()
        elif link_target.startswith('/dev/drbd/by-res/xcp-volume-'):
            self._refreshLinstorPath(link_target)

    def _refreshLvmPath(self):
        """Probe every known LV prefix and relink to each one that exists."""
        util.SMlog("Path inconsistent")
        vg_dir = "/dev/VG_XenStorage-%s/" % self.sr_uuid
        for lv_prefix in LV_PREFIX.values():
            candidate = vg_dir + lv_prefix + self.vdi_uuid
            util.SMlog("Testing path: %s" % candidate)
            if not os.path.exists(candidate):
                continue
            _mkphylink(self.sr_uuid, self.vdi_uuid, candidate)
            self.realpath = candidate
            self.vdi_type = LV_PREFIX_TO_VDI_TYPE[lv_prefix]

    def _refreshLinstorPath(self, link_target):
        """Recompute the LINSTOR device path and update the phy symlink.

        link_target -- current target of the phy symlink.
        Raises util.SMException when the LINSTOR libraries are missing or
        when this host's PBD for the SR cannot be found.
        """
        if not LINSTOR_AVAILABLE:
            raise util.SMException(
                'Can\'t refresh tapdisk: LINSTOR libraries are missing'
            )

        # The symlink must always be re-evaluated to be sure it carries the
        # right info: when a volume UUID changes in LINSTOR the symlink is
        # not updated directly. A live leaf coalesce performs these steps:
        # "A" -> "OLD_A"
        # "B" -> "A"
        # Without a symlink update the previous "A" path would be reused
        # instead of the "B" path. Note: "A", "B" and "OLD_A" are UUIDs.
        xapi = self.session

        this_host = util.get_this_host_ref(xapi)
        sr_ref = xapi.xenapi.SR.get_by_uuid(self.sr_uuid)

        pbd_ref = util.find_my_pbd(xapi, this_host, sr_ref)
        if pbd_ref is None:
            raise util.SMException('Failed to find PBD')

        device_config = xapi.xenapi.PBD.get_device_config(pbd_ref)
        group_name = device_config['group-name']

        volume_manager = LinstorVolumeManager(
            get_controller_uri(),
            group_name,
            logger=util.SMlog
        )

        from srmetadata import VDI_TYPE_TAG
        volume_metadata = volume_manager.get_volume_metadata(self.vdi_uuid)
        self.vdi_type = volume_metadata[VDI_TYPE_TAG]
        cow_util = LinstorCowUtil(xapi, volume_manager, self.vdi_type)
        device_path = cow_util.create_chain_paths(self.vdi_uuid)

        if link_target != device_path:
            util.SMlog(
                'Update LINSTOR PhyLink (previous={}, current={})'
                .format(link_target, device_path)
            )
            os.unlink(self.phypath)
            _mkphylink(self.sr_uuid, self.vdi_uuid, device_path)
        self.realpath = device_path

    def _loadTapNumbers(self):
        """Store the backend device's major/minor and check it is a tapdisk.

        Returns True when the major number equals the tapdisk major; logs
        and returns False otherwise.
        """
        self.major, self.minor = _getDevMajor_minor(self.path)
        if self.major == blktap2.Tapdisk.major():
            return True
        util.SMlog("Non-tap major number: %d" % self.major)
        return False

    def _activateParents(self, vdi):
        """Activate every LV in the parent chain of *vdi* except itself."""
        util.SMlog("Activating parents of %s" % self.vdi_uuid)
        vg_name = VG_PREFIX + self.sr_uuid
        lock_ns = NS_PREFIX_LVM + self.sr_uuid
        lvm_cache = lvmcache.LVMCache(vg_name)
        leaf_lv = LV_PREFIX[vdi.vdi_type] + self.vdi_uuid
        chain = getCowUtil(vdi.vdi_type).getParentChain(
            leaf_lv, LvmCowUtil.extractUuid, vg_name)
        for member_uuid, member_lv in chain.items():
            if member_uuid != self.vdi_uuid:
                lvm_cache.activate(lock_ns, member_uuid, member_lv, False)

    @locking("VDIUnavailable")
    def Pause(self):
        """Pause the tapdisk behind this VDI.

        Returns "True" when the tapdisk was paused or when the backend
        device does not exist, and "False" when the device is not a tapdisk.
        """
        util.SMlog("Pause for %s" % self.vdi_uuid)
        if not os.path.exists(self.path):
            util.SMlog("No %s: nothing to pause" % self.path)
            return str(True)
        if not self._loadTapNumbers():
            return str(False)

        util.SMlog("Calling tap pause with minor %d" % self.minor)
        blktap2.Tapdisk.get(minor=self.minor).pause()
        return str(True)

    @locking("VDIUnavailable")
    def Unpause(self):
        """Unpause the tapdisk behind this VDI, refreshing its paths first.

        Optionally activates the parent LVs and, when block tracking is
        enabled on the VDI, activates its CBT log before unpausing.

        Returns "True" when the tapdisk was unpaused or when the backend
        device does not exist, and "False" when the device is not a tapdisk.
        """
        util.SMlog("Unpause for %s" % self.vdi_uuid)
        if not os.path.exists(self.path):
            util.SMlog("No %s: nothing to unpause" % self.path)
            return str(True)
        self._pathRefresh()
        if not self._loadTapNumbers():
            return str(False)

        import VDI
        vdi = VDI.VDI.from_uuid(self.session, self.vdi_uuid)

        if self.activate_parents:
            self._activateParents(vdi)

        # Find out whether CBT is enabled on the disk about to be unpaused.
        if not vdi._get_blocktracking_status():
            self.cbtlog = None
        else:
            cbt_logname = vdi._get_cbt_logname(self.vdi_uuid)
            # The CBT log tied to the virtual disk has to be active before
            # it is used.
            vdi._activate_cbt_log(cbt_logname)
            self.cbtlog = vdi._get_cbt_logpath(self.vdi_uuid)

        util.SMlog("Calling tap unpause with minor %d" % self.minor)
        tap = blktap2.Tapdisk.get(minor=self.minor)
        tap.unpause(self.vdi_type, self.realpath, self.secondary, self.cbtlog)
        return str(True)


if __name__ == "__main__":
    # Map each plugin function name to its handler, then hand control to
    # the XenAPI plugin dispatcher.
    plugin_handlers = {
        "pause": tapPause,
        "unpause": tapUnpause,
        "refresh": tapRefresh,
    }
    XenAPIPlugin.dispatch(plugin_handlers)
