#!/usr/bin/env python

import datetime
import os
import re
import logging
import logging.handlers
import stat
import subprocess
import sys

try:
    import configparser
except ImportError:
    import ConfigParser as configparser

from string import Template

import zsnaplib

# Name of the logger shared by every function in this script.
LOGGER = 'zsnapper'

# Exit codes returned by main() and handed to sys.exit().
RET_CODES = {
    'SUCCESS': 0,
    'ERROR': 1,
    'FAILED': 2,
}

# Baseline settings for a filesystem; get_config_for_fs() copies this
# dict and layers the matching config sections on top of it.
DEFAULT_CONFIG = {
    # snapshot creation
    'snapshot_interval': None,
    'custom_keep_interval': None,

    # weeding: how many snapshots to keep per time bucket
    'weed_enable': False,
    'keep_yearly': 0,
    'keep_monthly': 0,
    'keep_weekly': 0,
    'keep_daily': 0,
    'keep_hourly': 0,
    'keep_30min': 0,
    'keep_15min': 0,
    'keep_5min': 0,
    'keep_1min': 0,
    'keep_custom': 0,

    # commands and health checks for the source and target side
    'source_zfs_cmd': '/sbin/zfs',
    'source_test_cmd': None,
    'target_fs': None,
    'target_zfs_cmd': '/sbin/zfs',
    'target_test_cmd': None,

    # replication
    'send_flags': '',
    'recv_flags': '',
    'send_enable': False,
}

# One "<digits><unit>" component of an interval string, where the unit
# is d (days), h (hours) or m (minutes); see str_to_timedelta().
timedelta_regex = re.compile(r'([0-9]+)([dhm])')

def fs_is_available(conf):
    """Run the configured health-check commands for one filesystem.

    ``conf['source_test_cmd']`` and ``conf['target_test_cmd']`` are both
    optional; a falsy value is skipped. ``$name`` placeholders in a command
    are filled in from *conf* (unknown placeholders are left untouched) and
    the result is split on whitespace into an argument list, so no shell is
    involved.

    Returns:
        True when every configured command exits with status 0. False as
        soon as one command exits non-zero or cannot be started at all.
    """
    log = logging.getLogger(LOGGER)
    for test in ('source_test_cmd', 'target_test_cmd'):
        if not conf[test]:
            continue
        cmdstr = Template(conf[test]).safe_substitute(conf)
        cmd = cmdstr.split()
        try:
            proc = subprocess.Popen(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE)
        except OSError as e:
            # A missing or non-executable binary means the check failed;
            # it must not abort the whole run with a traceback.
            log.warning('Healthcheck "{}" could not be run: {}'.format(cmdstr, e))
            return False
        # Drain both pipes so the child cannot block on a full pipe buffer;
        # the output itself is not used.
        proc.communicate()
        log.info('Healthcheck "{}" returned {}'.format(cmdstr, proc.returncode))
        if proc.returncode != 0:
            return False
    return True


def str_to_timedelta(deltastr):
    """Convert an interval string such as ``1d12h30m`` to a timedelta.

    Every ``<digits><unit>`` component found by ``timedelta_regex`` is
    added to the result (d = days, h = hours, m = minutes). Text that does
    not match is ignored, so a string without any component yields a zero
    timedelta.
    """
    unit_keywords = {'d': 'days', 'h': 'hours', 'm': 'minutes'}

    total = datetime.timedelta()
    for amount, unit in timedelta_regex.findall(deltastr):
        total += datetime.timedelta(**{unit_keywords[unit]: int(amount)})
    return total

def get_config_for_fs(fs, config):
    """Build the effective settings for one filesystem.

    *fs* is a dataset path, optionally followed by ``@remote``. Starting
    from a copy of DEFAULT_CONFIG, the config section of every ancestor
    path is applied from the top-level dataset down to *fs* itself, so more
    specific sections override inherited values. When a remote is given,
    only sections carrying the same ``@remote`` suffix are considered.

    The returned dict always has ``source_fs`` set to the dataset path
    without the ``@remote`` part.
    """
    path, _, remote = fs.partition('@')
    components = path.split('/')

    merged = DEFAULT_CONFIG.copy()
    for depth in range(1, len(components) + 1):
        section = '/'.join(components[:depth])
        if remote:
            section = "{}@{}".format(section, remote)
        if section in config:
            merged.update(config[section])

    merged['source_fs'] = path
    return merged


def do_snapshots(fslist, snapshots, config):
    """Create a snapshot of every filesystem whose interval has elapsed.

    A filesystem is skipped when it has no ``snapshot_interval``, when the
    interval parses to zero, or when its most recent snapshot
    (``snapshots[source_fs][0]``) is younger than the interval. A
    filesystem without any known snapshot is always due.

    Returns the set of *fslist* entries for which snapshot creation raised
    ZFSSnapshotError.
    """
    log = logging.getLogger(LOGGER)
    started = datetime.datetime.now()
    failures = set()

    for fs in fslist:
        conf = get_config_for_fs(fs, config)
        if not conf['snapshot_interval']:
            continue
        source_fs = conf['source_fs']

        zfs_cmd = Template(conf['source_zfs_cmd']).safe_substitute(conf).split()
        interval = str_to_timedelta(conf['snapshot_interval'])

        # No usable previous snapshot: pretend the last one is ancient.
        last_snap = datetime.datetime.min
        if source_fs in snapshots:
            known = snapshots[source_fs]
            if known and known[0]:
                last_snap = known[0]

        if not (interval > datetime.timedelta()):
            continue
        if not (last_snap + interval < started):
            continue

        try:
            zsnaplib.create_snapshot(source_fs, zfs_cmd)
        except zsnaplib.ZFSSnapshotError as e:
            log.warning(e)
            failures.add(fs)
        else:
            log.info('{} snapshot created using {}'.format(fs, zfs_cmd))

    return failures

def get_remote_sources(config):
    """Collect the zfs command for every reachable remote source.

    Only config sections named ``<fs>@<remote>`` that define
    ``source_zfs_cmd`` themselves are considered, and only when the health
    checks for that section pass. The command template is expanded with
    the section's own values and split on whitespace.

    Returns a dict mapping the remote name to the command's argument list;
    if several sections name the same remote, the last one wins.
    """
    sources = {}
    for section in config.sections():
        if '@' not in section:
            continue
        settings = config[section]
        if 'source_zfs_cmd' not in settings:
            continue

        remote = section.split('@', 1)[1]
        if not fs_is_available(get_config_for_fs(section, config)):
            continue

        command = Template(settings['source_zfs_cmd']).safe_substitute(settings)
        sources[remote] = command.split()
    return sources


def send_snapshots(fslist, snapshots, config):
    """Send the newest snapshot of each send-enabled filesystem to its target.

    Args:
        fslist: filesystem names to process; an entry may carry an
            ``@remote`` suffix, which get_config_for_fs() uses to select
            the matching config sections.
        snapshots: mapping of source filesystem name to its snapshots.
            Index 0 is treated as the latest snapshot and index -1 as the
            oldest throughout this function.
        config: the parsed configuration.

    For every filesystem with ``send_enable`` set, the target's snapshots
    are listed. If the target filesystem is missing there, a base copy is
    sent first (the latest snapshot when ``send_enable`` is ``latest``,
    otherwise the oldest). After that an incremental send from the last
    common snapshot to the latest local one is performed, unless the
    target is already up to date.

    Returns:
        The set of *fslist* entries that could not be sent: failed health
        check, no local snapshots, no common snapshot, or a
        ZFSSnapshotError raised by the send itself.
    """
    failed_snapshots = set()
    remote_hosts = {}
    remote_targets = {}
    log = logging.getLogger(LOGGER)
    for fs in fslist:
        conf = get_config_for_fs(fs, config)
        remote_snapshots = None
        # Values read from the config file are strings, so an explicit
        # "false"/"no" has to be recognised as disabled, exactly like
        # weed_enable is handled in weed_snapshots().
        if (not conf['send_enable']
                or str(conf['send_enable']).lower() in ('false', 'no')):
            continue
        if not fs_is_available(conf):
            failed_snapshots.add(fs)
            continue

        repl_mode = conf['send_enable']
        target_fs = conf['target_fs']
        source_fs = conf['source_fs']
        send_opts = []
        recv_opts = []
        if conf['send_flags']:
            send_opts = conf['send_flags'].split()
        if conf['recv_flags']:
            recv_opts = conf['recv_flags'].split()

        # A filesystem that inherited target_fs from an already processed
        # filesystem gets its path relative to that one appended, so that
        # descendants do not all land on the very same target.
        rel_local = [k for k, v in remote_targets.items() if v == target_fs]
        if rel_local:
            rel_local = rel_local[0]
            rel_fs = source_fs[len(rel_local):]
            target_fs = '{}{}'.format(target_fs, rel_fs)
        remote_targets[source_fs] = target_fs

        # Nothing can be sent without at least one local snapshot. This is
        # checked only after the target bookkeeping above so descendants
        # still resolve their target path correctly.
        if source_fs not in snapshots or not snapshots[source_fs]:
            failed_snapshots.add(fs)
            log.warning('{}: no local snapshots available, nothing to send'.format(fs))
            continue
        local_snapshots = snapshots[source_fs]

        # Figure out the state of remote zfs
        target_zfs_cmd = Template(conf['target_zfs_cmd']).safe_substitute(conf)
        target_zfs_cmd = target_zfs_cmd.split()
        source_zfs_cmd = Template(conf['source_zfs_cmd']).safe_substitute(conf)
        source_zfs_cmd = source_zfs_cmd.split()
        # to avoid running too many commands on remote host, save result if we
        # know which host we're working with.
        if 'target_host' in conf:
            if conf['target_host'] in remote_hosts:
                remote_snapshots = remote_hosts[conf['target_host']]
            else:
                remote_snapshots = zsnaplib.get_snapshots(target_zfs_cmd)
                remote_hosts[conf['target_host']] = remote_snapshots
        if not remote_snapshots:
            remote_snapshots = zsnaplib.get_snapshots(target_zfs_cmd)

        if target_fs not in remote_snapshots:
            # Remote FS doesn't exist, send a new copy
            log.info('{} sending base copy to {}'.format(fs, ' '.join(target_zfs_cmd)))
            # oldest snapshot is base_snap if repl_mode != latest
            base_snap = local_snapshots[-1]
            if repl_mode == 'latest':
                base_snap = local_snapshots[0]
            try:
                zsnaplib.send_snapshot(
                        source_fs,
                        base_snap,
                        target_zfs_cmd,
                        target_fs,
                        source_zfs_cmd,
                        send_opts=send_opts,
                        recv_opts=recv_opts)
                log.info('{} base copy sent'.format(fs))
            except zsnaplib.ZFSSnapshotError as e:
                failed_snapshots.add(fs)
                log.warning(e)
                continue
            # Record the base copy so the search below finds it.
            remote_snapshots[target_fs] = [base_snap]

        # Remote FS now exists, one way or another find last common snapshot
        last_remote = None
        for remote_snap in remote_snapshots[target_fs]:
            if remote_snap in local_snapshots:
                last_remote = remote_snap
                break
        if not last_remote:
            failed_snapshots.add(fs)
            log.warning('{}: No common snapshot local and remote, you need to create a new base copy!'.format(fs))
            continue
        last_local = local_snapshots[0]
        if last_remote == last_local:
            log.info("{} snapshot from {} is already present at target".format(fs, last_local))
            continue

        log.info('{} incremental {} -> {}, remote is {}'.format(fs, last_remote, last_local, ' '.join(target_zfs_cmd)))
        try:
            zsnaplib.send_snapshot(
                    source_fs,
                    last_local,
                    target_zfs_cmd,
                    target_fs,
                    source_zfs_cmd,
                    send_opts=send_opts,
                    recv_opts=recv_opts,
                    repl_from=last_remote,
                    repl_mode=repl_mode)
            log.info('{} successfully sent to remote'.format(fs))
        except zsnaplib.ZFSSnapshotError as e:
            log.warning(e)
            failed_snapshots.add(fs)
    return failed_snapshots

def weed_snapshots(fslist, snapshots, config, failed_snapshots):
    """Remove surplus snapshots according to each filesystem's keep_* settings.

    Args:
        fslist: filesystem names to process; an entry may carry an
            ``@remote`` suffix.
        snapshots: mapping of source filesystem name to its snapshots,
            with the latest snapshot at index 0.
        config: the parsed configuration.
        failed_snapshots: *fslist* entries whose snapshot creation or send
            failed during this run; they are never weeded.

    A filesystem is skipped when it is listed in *failed_snapshots*, has no
    entry in *snapshots*, or has ``weed_enable`` unset or set to
    ``false``/``no``. The latest snapshot is never offered for removal.
    """
    log = logging.getLogger(LOGGER)
    for fs in fslist:
        conf = get_config_for_fs(fs, config)
        source_fs = conf['source_fs']
        if fs in failed_snapshots:
            log.info("Not weeding {} because of snapshot creation/send failure".format(fs))
            continue
        if source_fs not in snapshots:
            continue
        if not conf['weed_enable'] or conf['weed_enable'].lower() in ('false', 'no'):
            continue

        kwargs = {k: int(v) for k, v in conf.items() if k in [
                'keep_custom',
                'keep_yearly',
                'keep_monthly',
                'keep_weekly',
                'keep_daily',
                'keep_hourly',
                'keep_30min',
                'keep_15min',
                'keep_5min',
                'keep_1min']}
        if conf['custom_keep_interval']:
            kwargs['custom_keep_interval'] = str_to_timedelta(conf['custom_keep_interval'])

        zfs_cmd = Template(conf['source_zfs_cmd']).safe_substitute(conf)
        zfs_cmd = zfs_cmd.split()

        # Pass the bare dataset name, as the create and send calls in this
        # file do: `fs` may still carry the "@remote" suffix, which is only
        # a config selector and not part of the dataset name.
        zsnaplib.weed_snapshots(
                source_fs,
                # never remove the latest snapshot
                snapshots[source_fs][1:],
                zfs_cmd,
                **kwargs)


def main():
    """Run one full pass: create snapshots, then send and weed under a lock.

    Local snapshots are always created. Everything after that (remote
    snapshot creation, sending and weeding) is guarded by the pid lock file
    ``/tmp/zsnapper.pid`` so that a long-running send from a previous
    invocation is not disturbed.

    Returns:
        RET_CODES['SUCCESS'] when everything worked, RET_CODES['ERROR']
        when creating or sending a snapshot failed somewhere, and
        RET_CODES['FAILED'] when the lock file exists but does not contain
        a pid.
    """
    config = configparser.ConfigParser()
    config.read(['/usr/local/etc/zsnapper.ini', '/etc/zsnapper.ini'])

    ret = RET_CODES['SUCCESS']
    log = logging.getLogger(LOGGER)

    # guess the local zfs command, this is pretty ugly...
    zfs_cmd_conf = DEFAULT_CONFIG
    for section in config.sections():
        if '@' not in section:
            if 'source_zfs_cmd' in config[section]:
                zfs_cmd_conf = get_config_for_fs(section, config)
    local_zfs_cmd = Template(zfs_cmd_conf['source_zfs_cmd']).safe_substitute(zfs_cmd_conf)
    local_zfs_cmd = local_zfs_cmd.split()

    fslist = sorted(zsnaplib.get_filesystems(local_zfs_cmd))
    snapshots = zsnaplib.get_snapshots(local_zfs_cmd)

    failed_snapshots = do_snapshots(fslist, snapshots, config)
    if failed_snapshots:
        ret = RET_CODES['ERROR']

    lockfile = '/tmp/zsnapper.pid'
    # This loop should run at most twice
    while True:
        try:
            # O_EXCL makes creation fail if the file already exists.
            lockfd = os.open(lockfile, os.O_CREAT|os.O_EXCL|os.O_WRONLY, mode=0o640)
            os.write(lockfd, "{}".format(os.getpid()).encode('utf-8'))
            os.close(lockfd)
            break
        except OSError:
            pass

        # lock file exists, check if the pid seems valid
        with open(lockfile, 'r') as f:
            pid = f.read()
        try:
            pid = int(pid)
            # Signal 0 only checks that the process can be signalled.
            os.kill(pid, 0)
            # If we got here the lock is owned by an existing pid
            log.info('Previous run is not completed yet, will not send or weed snapshots')
            return ret
        except OSError:
            # pid is not running, forcing unlock
            os.remove(lockfile)
        except ValueError:
            log.error('lockfile {} exists but does not seem to contain a pid. Will not continue'.format(lockfile))
            return RET_CODES['FAILED']

    # From here on the lock is ours; release it however this block is left
    # so an unexpected exception does not leave the lock file behind.
    try:
        # create any remote snapshots
        remotes = get_remote_sources(config)
        remote_fs = {}
        remote_snapshots = {}
        failed_remote_snapshots = {}
        # Iterate over a copy: a failing remote is removed from `remotes`,
        # and changing a dict while iterating over it raises RuntimeError.
        for remote, zfs_cmd in list(remotes.items()):
            try:
                remote_fs[remote] = sorted(zsnaplib.get_filesystems(zfs_cmd))
                remote_snapshots[remote] = zsnaplib.get_snapshots(zfs_cmd)
                failed_remote_snapshots[remote] = do_snapshots(
                        ["{}@{}".format(x, remote) for x in remote_fs[remote]],
                        remote_snapshots[remote],
                        config)
            except zsnaplib.ZFSSnapshotError as e:
                remote_fs.pop(remote, None)
                remote_snapshots.pop(remote, None)
                remotes.pop(remote, None)
                log.warning("Failed to snapshot on {}: {}".format(remote, e))
                ret = RET_CODES['ERROR']

        for remote, filesystems in failed_remote_snapshots.items():
            for fs in filesystems:
                log.warning("Failed to snapshot {} on {}".format(fs, remote))
            if filesystems:
                # Treat these like local creation failures: they affect the
                # exit code and must not be weeded in this run.
                failed_snapshots.update(filesystems)
                ret = RET_CODES['ERROR']

        # reload all snapshots so we get our new snapshots here
        for remote, zfs_cmd in list(remotes.items()):
            try:
                remote_snapshots[remote] = zsnaplib.get_snapshots(zfs_cmd)
            except zsnaplib.ZFSSnapshotError as e:
                # Drop the remote everywhere so the send and weed loops
                # below do not look up data that no longer exists.
                remote_fs.pop(remote, None)
                remote_snapshots.pop(remote, None)
                remotes.pop(remote, None)
                log.warning("Could not refresh snapshots on {}: {}".format(remote, e))
                ret = RET_CODES['ERROR']
        snapshots = zsnaplib.get_snapshots(local_zfs_cmd)

        failed_send = send_snapshots(fslist, snapshots, config)
        if failed_send:
            # Same treatment as a failed send from a remote source below.
            ret = RET_CODES['ERROR']
        failed_snapshots.update(failed_send)
        for remote in remotes:
            failed_send = send_snapshots(
                    ["{}@{}".format(x, remote) for x in remote_fs[remote]],
                    remote_snapshots[remote],
                    config)
            if failed_send:
                ret = RET_CODES['ERROR']
            failed_snapshots.update(failed_send)

        weed_snapshots(fslist, snapshots, config, failed_snapshots)

        for remote in remotes:
            weed_snapshots(
                    ["{}@{}".format(x, remote) for x in remote_fs[remote]],
                    remote_snapshots[remote],
                    config,
                    failed_snapshots)
    finally:
        os.remove(lockfile)
    return ret

if __name__ == '__main__':
    # Script entry point: set up logging, run main() and exit with its
    # return code.
    log = logging.getLogger(LOGGER)
    log.setLevel(logging.INFO)

    # Warnings and errors are echoed on the console stream (stderr by
    # default for StreamHandler).
    console = logging.StreamHandler()
    console.setLevel(logging.WARNING)
    log.addHandler(console)

    # Everything from INFO upwards goes to syslog, using the first of the
    # candidate paths that exists and is a socket.
    for logsocket in ('/var/run/log', '/dev/log'):
        try:
            is_socket = stat.S_ISSOCK(os.stat(logsocket).st_mode)
        except FileNotFoundError:
            continue
        if not is_socket:
            continue

        syslog = logging.handlers.SysLogHandler(address=logsocket)
        syslog.setFormatter(
            logging.Formatter(fmt='zsnapper[%(process)s] %(message)s'))
        syslog.setLevel(logging.INFO)
        log.addHandler(syslog)
        break
    else:
        # Loop ended without attaching a syslog handler.
        log.warning('No syslog socket found, will not log to syslog')

    sys.exit(main())
