#!/bin/python3
import os
import sys
import logging
import argparse
import subprocess

from tempfile import NamedTemporaryFile

import evdev

# Configure the root logger so that INFO and above are emitted.
logging.basicConfig(level=logging.INFO)
# Root logger, used for this script's own messages.
logger = logging.getLogger()
# Dedicated logger used to relay the output of virsh invocations.
virsh_logger = logging.getLogger('virsh')

# Temporary hostdev XML files, one per USB device. Filled in by the
# __main__ block and read by on_grab() / on_release().
usb_xmls = []

def listen():
    """Watch the keyboard for the both-Ctrl chord and toggle USB attachment.

    Reads events from the evdev device at ``args.evdev_kbd`` in an endless
    loop.  Once both Ctrl keys have been held down together and then both
    been released, ``on_grab()`` is called and the listener enters the
    "grabbed" state.  While grabbed, the first event with a non-zero code
    calls ``on_release()`` and leaves that state.

    Relies on the module-level ``args`` namespace set by the ``__main__``
    block.  Does not return normally; the input device is closed when the
    read loop is left by an exception.
    """
    logger.info('Listening for events on %s', args.evdev_kbd)
    device = evdev.InputDevice(args.evdev_kbd)

    # Resolve the constants once instead of on every event.
    ev_key = evdev.ecodes.EV_KEY
    key_left_ctrl = evdev.ecodes.KEY_LEFTCTRL
    key_right_ctrl = evdev.ecodes.KEY_RIGHTCTRL

    left_ctrl = False
    right_ctrl = False
    triggered = False
    grabbed = False
    try:
        for event in device.read_loop():
            logger.debug('Got event: %s', event)

            # Code 0 events (e.g. SYN_REPORT) are ignored here; anything else
            # seen while grabbed ends the grab.
            # NOTE(review): presumably the VM holds the keyboard exclusively
            # while grabbed, so seeing events again means it let go -- confirm.
            if grabbed and event.code != 0:
                grabbed = False
                on_release()

            # Event codes are only unique within one event type, so only
            # EV_KEY events may update the Ctrl state.
            if event.type == ev_key:
                if event.code == key_left_ctrl:
                    left_ctrl = event.value != 0
                elif event.code == key_right_ctrl:
                    right_ctrl = event.value != 0

            if left_ctrl and right_ctrl:
                # Chord is down; act only once both keys are up again.
                triggered = True
            elif triggered and not left_ctrl and not right_ctrl:
                triggered = False
                grabbed = True
                on_grab()
    finally:
        device.close()

def on_grab():
    """Attach every configured USB device to the libvirt domain."""
    logger.info('Grabbed')
    xml_paths = [xml_file.name for xml_file in usb_xmls]
    for xml_path in xml_paths:
        invoke_virsh('attach-device', args.domain, xml_path)

def on_release():
    """Detach every configured USB device from the libvirt domain."""
    logger.info('Released')
    xml_paths = [xml_file.name for xml_file in usb_xmls]
    for xml_path in xml_paths:
        invoke_virsh('detach-device', args.domain, xml_path)

def invoke_virsh(subcommand, domain, xml_name):
    """Run ``virsh <subcommand> <domain> <xml_name>`` and log its output.

    Every line virsh prints on stdout is logged at INFO and every line on
    stderr at ERROR, both on the 'virsh' logger.  The exit status of the
    command is not inspected.
    """
    command = ('virsh', subcommand, domain, xml_name)
    completed = subprocess.run(command,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               encoding='utf8')

    # Pair each captured stream with the log method it is relayed through.
    relays = ((virsh_logger.info, completed.stdout),
              (virsh_logger.error, completed.stderr))
    for emit, text in relays:
        for line in text.strip().splitlines():
            emit(line)

if __name__ == '__main__':
    # Entry point: parse the command line, write one temporary libvirt
    # <hostdev> XML file per USB device, then listen for the hotkey until
    # the process is stopped.
    if os.geteuid() != 0:
        logger.error('This script must be run as root')
        sys.exit(1)

    parser = argparse.ArgumentParser()
    parser.add_argument('--evdev-kbd',
                        help='The evdev keyboard device to read eg: /dev/input/...',
                        required=True)
    parser.add_argument('domain',
                        help='The libvirt domain to attach the usb devices to')
    parser.add_argument('usb_ids',
                        help='The ids of the usb devices to attach, formatted as idVendor:idProduct',
                        nargs='+')
    # `args` is intentionally module-level: listen(), on_grab() and
    # on_release() read it as a global.
    args = parser.parse_args()
    logger.debug(args)

    # Validate every id before creating any file, so a typo becomes a usage
    # error instead of a ValueError traceback or malformed XML sent to virsh.
    hex_digits = set('0123456789abcdefABCDEF')
    usb_id_pairs = []
    for usb_id in args.usb_ids:
        parts = usb_id.split(':')
        if len(parts) != 2 or not all(part and set(part) <= hex_digits
                                      for part in parts):
            parser.error("invalid usb id '%s': expected idVendor:idProduct "
                         "in hexadecimal, eg: 046d:c52b" % usb_id)
        usb_id_pairs.append(parts)

    try:
        for id_vendor, id_product in usb_id_pairs:
            usb_xml = NamedTemporaryFile('w+')
            # Register the file right away so the finally block closes it
            # even if one of the writes below fails.
            usb_xmls.append(usb_xml)

            usb_xml.write('<hostdev mode="subsystem" type="usb" managed="yes">\n')
            usb_xml.write(' <source>\n')
            usb_xml.write('  <vendor id="0x%s"/>\n' % id_vendor)
            usb_xml.write('  <product id="0x%s"/>\n' % id_product)
            usb_xml.write('</source>\n')
            usb_xml.write('</hostdev>')
            # virsh opens the file by name, so the buffered content must
            # be on disk before the first attach.
            usb_xml.flush()

            usb_xml.seek(0)
            logger.debug('Created usb device xml:\n%s', usb_xml.read())
        listen()
    finally:
        # Closing a NamedTemporaryFile also removes it (default delete=True).
        for usb_xml in usb_xmls:
            usb_xml.close()