#! /usr/bin/python3
#
# Copyright (C) 2022 Citrix Systems R&D Ltd.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation; version 2.1 only. with the special
# exception on linking described in file LICENSE.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.

import os
import stat
import socket
import sys
import pwd
import grp
import subprocess
import ctypes
import ctypes.util
import struct
from resource import getrlimit, RLIMIT_CORE, RLIMIT_FSIZE, setrlimit
from urllib.parse import urlparse

# Namespace selection flags passed to unshare(2); the values mirror the
# Linux CLONE_* constants of the same names.
CLONE_NEWNS  = 0x00020000 # mount namespace
CLONE_NEWNET = 0x40000000 # network namespace
CLONE_NEWIPC = 0x08000000 # IPC namespace

# Magic number compared against the first 8 bytes of a block-device state
# store; its hex digits are the ASCII bytes of "swtpmlin".
SWTPM_NVSTORE_LINEAR_MAGIC = 0x737774706d6c696e

def unshare(flags):
    """Call the C library's unshare() with *flags*.

    Returns None on success.  Raises OSError carrying the C errno and its
    message when the call reports failure (negative return value).
    """
    c_library = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True)
    prototype = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, use_errno=True)
    native_unshare = prototype(('unshare', c_library))

    if native_unshare(flags) >= 0:
        return

    errno_value = ctypes.get_errno()
    raise OSError(errno_value, os.strerror(errno_value))

def enable_core_dumps():
    """Set the core-dump size limit to 64 MiB and return it.

    The soft RLIMIT_CORE limit is always set to 64 MiB.  The hard limit is
    raised to 64 MiB when it compares lower than that, and is otherwise
    left unchanged.

    Returns:
        The soft limit that was applied, in bytes (always an int, so
        callers can format it with "%d").

    Raises:
        ValueError / OSError: propagated from setrlimit(), e.g. when the
        process is not allowed to raise the hard limit.
    """
    limit = 64 * 1024 * 1024
    hardlimit = getrlimit(RLIMIT_CORE)[1]

    # NOTE(review): where resource.RLIM_INFINITY is -1 an unlimited hard
    # limit also compares below `limit` and is therefore clamped to 64 MiB,
    # exactly as before -- confirm that this is wanted.
    if limit > hardlimit:
        hardlimit = limit

    # Apply the limit and report it on every path; previously this only
    # happened when the hard limit had to be raised, and None was returned
    # otherwise.
    setrlimit(RLIMIT_CORE, (limit, hardlimit))
    return limit

def prepare_exec():
    """Apply resource limits and namespace isolation before exec'ing SWTPM.

    Steps, in order: enable core dumps and print the resulting limit, cap
    the maximum file size at 256 KiB, unshare the mount, IPC and network
    namespaces, then flush stdout and stderr.
    """
    print("core dump limit: %d" % enable_core_dumps())

    fsize_cap = 256 * 1024
    setrlimit(RLIMIT_FSIZE, (fsize_cap, fsize_cap))

    unshare(CLONE_NEWNS | CLONE_NEWIPC | CLONE_NEWNET)

    # Push out anything still buffered before the caller replaces the
    # process image.
    for stream in (sys.stdout, sys.stderr):
        stream.flush()

def make_socket(fname):
    """Return a UNIX stream socket bound to *fname* and listening on it.

    The listen backlog is 1.  The target path is printed before binding.
    """
    listener = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    print("Binding socket to %s" % fname)
    listener.bind(fname)
    listener.listen(1)
    return listener

def check_state_needs_init(fname):
    """Decide whether the TPM state at *fname* still has to be initialised.

    Args:
        fname: path of the state file or block device, or None when there
            is no local state to inspect.

    Returns:
        False when *fname* is None, or when it exists and is not a block
        device.  True when the path does not exist.  For a block device,
        True unless its first 8 bytes, read as a little-endian unsigned
        64-bit integer, equal SWTPM_NVSTORE_LINEAR_MAGIC.

    Raises:
        OSError: propagated from os.stat() / open() / read().
    """
    if fname is None:
        return False
    print("Checking TPM state %s" % fname)

    if not os.path.exists(fname):
        return True

    mode = os.stat(fname).st_mode

    if not stat.S_ISBLK(mode):
        return False

    # Check whether the block device starts with the expected magic.
    # Binary mode is required: the header is raw bytes, and struct.unpack()
    # only accepts a bytes-like buffer.
    with open(fname, "rb") as file:
        hdr = file.read(8)

    # A short read cannot contain the magic; test the length before
    # unpacking, because struct.unpack() raises on a wrongly sized buffer.
    if len(hdr) != 8:
        return True

    return struct.unpack("<Q", hdr)[0] != SWTPM_NVSTORE_LINEAR_MAGIC


def main(argv):
    """Prepare the environment for one domain's TPM and exec swtpm.

    Arguments read from *argv*:
        argv[1]   domain id (integer)
        argv[2]   TPM working directory
        argv[3]   TPM state URI; supported schemes are "dir", "file",
                  "http" and "unix+http"
        argv[4]   "true" to force initial manufacture via swtpm_setup
        argv[5:]  optional "--priv" to skip the chroot/--runas
                  deprivileging

    On success this function does not return: the process image is
    replaced with swtpm_setup (when the state needs initialising) or swtpm
    through os.execve().

    Returns:
        1 when preparing the chroot directory fails.

    Raises:
        SystemExit: on missing arguments, an unknown option, an unknown
            state scheme, or a scheme that cannot be initialised locally.
    """
    print("Arguments: %s" % " ".join(argv[1:]))

    if len(argv) < 5:
        sys.exit("Not enough arguments.")

    domid = int(argv[1])
    tpm_dir = argv[2]
    tpm_state = argv[3]
    needs_init = argv[4] == "true"
    tpm_path = tpm_dir
    depriv = True

    for arg in argv[5:]:
        if arg == "--priv":
            depriv = False
        else:
            sys.exit("Unknown option %s\n" % arg)

    parsed = urlparse(tpm_state)
    tpm_state_scheme = parsed.scheme
    tpm_state_file = parsed.path.lstrip('/')
    tpm_file = os.path.join(tpm_dir, tpm_state_file)
    tpm_args = []

    # ensure that paths work regardless of whether we are in the chroot or not:
    # we'll only use relative paths
    os.chdir(tpm_dir)

    if tpm_state_scheme == "dir":
        tpm_state_file = "tpm2-00.permall"
        tpm_file = os.path.join(tpm_dir, tpm_state_file)
    elif tpm_state_scheme == "file":
        pass
    elif tpm_state_scheme in ("http", "unix+http"):
        # Remote state: there is no local file to inspect or chown.
        tpm_state_file = None
        tpm_args = ["--seccomp", "action=none"]  # TODO: due to curl syscalls
        tpm_file = None
    else:
        sys.exit("Unknown state scheme  %s\n" % tpm_state_scheme)

    tpm_env = dict(os.environ)
    # NOTE(review): the trailing ':' leaves an empty entry in the search
    # path, which the dynamic loader treats as the current directory --
    # confirm this is intended.
    tpm_env["LD_LIBRARY_PATH"] = "/usr/lib:"

    if needs_init or check_state_needs_init(tpm_state_file):
        if tpm_file is None:
            sys.exit("Unsupported scheme for TPM initialization: %s" % tpm_state_scheme)
        print('Initial manufacture')
        tpm_exe = '/usr/bin/swtpm_setup'
        tpm_args = ["swtpm_setup", "--tpm2", "--tpm-state", tpm_dir, "--createek", "--create-ek-cert", "--create-platform-cert", "--lock-nvram", "--not-overwrite"]
        print('Running %s' % tpm_args)
        prepare_exec()
        # Does not return: the process becomes swtpm_setup.
        os.execve(tpm_exe, tpm_args, tpm_env)

    tpm_exe = '/usr/bin/swtpm'
    # Each domain gets its own uid, offset from the swtpm_base account.
    uid = pwd.getpwnam('swtpm_base').pw_uid + domid

    if depriv:
        # Append rather than assign, so that scheme-specific options chosen
        # above (the seccomp override for http backends) are not discarded.
        tpm_args += ["--chroot", tpm_dir,
                     "--runas", str(uid)]
        try:
            dev_dir = os.path.join(tpm_dir, "dev")
            if not os.path.isdir(dev_dir):
                os.mkdir(dev_dir)

            # Provide a urandom character device (major 1, minor 9) inside
            # the chroot.
            urandom = os.path.join(dev_dir, "urandom")
            if not os.path.exists(urandom):
                os.mknod(urandom, 0o666 | stat.S_IFCHR, os.makedev(1, 9))

            if os.path.exists(os.path.join(tpm_dir, ".lock")):
                os.chown(os.path.join(tpm_dir, ".lock"), uid, uid)
            if tpm_file:
                os.chown(tpm_file, uid, uid)

            if tpm_state_scheme != "file":
                os.chown(tpm_dir, uid, uid)
            else:
                # For file-backed state the directory stays owned by uid 0
                # and is only made group-accessible to the per-domain id.
                os.chmod(tpm_dir, 0o750)
                os.chown(tpm_dir, 0, uid)

        except OSError as error:
            print(error)
            # Report the failure through the exit status; a bare return
            # made the script exit with status 0.
            return 1

        # Paths handed to swtpm are resolved inside the chroot.
        tpm_path = '/'

    swtpm_sock = os.path.join(tpm_dir, "swtpm-sock")
    swtpm_pid = os.path.join(tpm_path, "swtpm-%d.pid" % domid)
    sock = make_socket(swtpm_sock)

    # the PID file is taken by the toolstack as signal that the socket
    # is ready
    swtpm_pid_full = os.path.join(tpm_dir, "swtpm-%d.pid" % domid)
    open(swtpm_pid_full, 'wb').close()
    os.chown(swtpm_pid_full, uid, uid)

    if tpm_state_scheme == "dir":
        state_param = "dir=%s" % tpm_path
    else:
        state_param = "backend-uri=%s" % tpm_state

    tpm_args = ["swtpm-%d" % domid, "socket",
                "--tpm2",
                "--tpmstate", state_param,
                "--ctrl", "type=unixio,fd=%i" % sock.fileno(),
                "--log", "level=1",
                "--pid", "file=%s" % swtpm_pid,
                "-t"] + tpm_args

    print("Exec: %s %s" % (tpm_exe, " ".join(tpm_args)))
    # by default sockets are CLOEXEC
    os.set_inheritable(sock.fileno(), True)
    prepare_exec()
    os.execve(tpm_exe, tpm_args, tpm_env)

# Script entry point: whatever main() returns becomes the exit status.
if __name__ == '__main__':
    sys.exit(main(sys.argv))
