#!/usr/bin/env python3
"""
Script to filter out unwanted tables from MySQL/MariaDB SQL dumps (INSERT-only format).
Removes entire table sections including LOCK/UNLOCK and INSERT statements.
Also renames tables and columns in INSERT statements with column lists.
"""

import sys
import re
import argparse

# Legacy table names whose dump sections are left out of the output.
# main() passes this list to filter_sql_dump() as ``tables_to_remove``.
# The users_groups table and the four Checks*Linker tables are still read by
# filter_sql_dump()'s first pass (admin flags / alerting associations) even
# though their own sections are dropped.
TABLES_TO_REMOVE = [
    'nagiosna_sessions',
    'nagiosna_Reports',
    'nagiosna_cmdsubsys',
    'nagiosna_hostname_cache',
    'nagiosna_login_attempts',
    'nagiosna_path_nodes',
    'nagiosna_path_status',
    'nagiosna_usermeta',
    'nagiosna_migrations',
    'nagiosna_Queries',
    'nagiosna_SourcesViewsLinker',
    'nagiosna_ChecksCmdLinker',
    'nagiosna_ChecksHSALinker',
    'nagiosna_ChecksTrapsLinker',
    'nagiosna_ChecksUsersLinker',
    'nagiosna_groups',
    'nagiosna_users_groups',
]

# Legacy table name -> new table name.
# filter_sql_dump() applies these to section header comments, LOCK TABLES,
# ALTER TABLE ... DISABLE/ENABLE KEYS and INSERT INTO lines.
TABLE_RENAMES = {
    'nagiosna_Commands': 'commands',
    'nagiosna_HostServiceAssociations': 'service_hostnames',
    'nagiosna_NagiosServers': 'nagios_servers',
    'nagiosna_SGLinker': 'source_group_linker',
    'nagiosna_SourceGroups': 'source_groups',
    'nagiosna_Sources': 'sources',
    'nagiosna_TrapReceivers': 'snmp_receivers',
    'nagiosna_auth_servers': 'auth_servers',
    'nagiosna_cf_options': 'config_options',
    'nagiosna_Views': 'views',
    'nagiosna_users': 'users',
    'nagiosna_Checks': 'checks',
}

# Per-table column renames: {new_table_name: {old_column: new_column}}.
# Keys are the *renamed* table names (values of TABLE_RENAMES), because
# rename_columns_in_insert() is called with the table name after renaming.
COLUMN_RENAMES = {
    'nagios_servers': {
        'nrdp': 'nrdp_url',
        'token': 'nrdp_token'
    },
    'service_hostnames': {
        'serverid': 'server_id'
    },
    'source_group_linker': {
        'gid': 'source_group_id',
        'sid': 'source_id'
    },
    'source_groups': {
        'gid': 'id',
    },
    'sources': {
        'sid': 'id',
        'addresses': 'description',
    },
    'snmp_receivers': {
        'tid': 'id',
        'community': 'community_string',
        'authlevel': 'auth_level',
        'privprotocol': 'priv_protocol',
        'privpassword': 'priv_password',
        'authprotocol': 'auth_protocol',
        'authpassword': 'auth_password'
    },
    'views': {
        'vid': 'id'
    },
    'checks': {
        'cid': 'id',
        'rawquery': 'raw_query',
        'lastval': 'last_val',
        'lastrun': 'last_run',
        'lastcode': 'last_code',
        'laststdout': 'last_stdout'
    }
}

# Columns (and their values) dropped from INSERT statements:
# {new_table_name: {column, ...}}.
# Column names are the pre-rename names, because filter_sql_dump() calls
# remove_columns_from_insert() before rename_columns_in_insert().
SQL_REMOVAL = {
    'config_options': {'id', 'resolve_hosts', 'resolve_host_graphs', 'cache_time', 'rel_map_max'},
    'users': {'salt', 'activation_code', 'forgotten_password_code', 'forgotten_password_time', 'remember_code', 'ip_address'},
    'checks': {'sid', 'gid', 'vid', 'pid', 'aberrant'},
    'sources': {'disable_abnormal'},
}

# Tables whose rebuilt INSERT statements should use
# INSERT ... ON DUPLICATE KEY UPDATE.
# Format: table_name -> {'unique_column': column, 'update_columns': [column, ...]}
# build_insert_statement() reads only 'update_columns'; 'unique_column' is
# informational (it names the key expected to trigger the duplicate).
UPSERT_TABLES = {
    'config_options': {
        'unique_column': 'name',
        'update_columns': ['value', 'modified', 'created_by']  # Only update these on duplicate name
    }
}


def rename_table_in_line(line, old_table, new_table):
    """
    Replace every reference to ``old_table`` in ``line`` with ``new_table``.

    Backtick-quoted occurrences are swapped first, then any remaining
    occurrence of the name that stands as a whole word. Longer identifiers
    that merely start with ``old_table`` are left alone because ``_`` counts
    as a word character for ``\\b``.

    Args:
        line: Text to rewrite
        old_table: Table name to look for
        new_table: Replacement table name

    Returns:
        str: The rewritten line
    """
    quoted_done = line.replace('`' + old_table + '`', '`' + new_table + '`')
    whole_word = re.compile(r'\b' + re.escape(old_table) + r'\b')
    return whole_word.sub(new_table, quoted_done)


def parse_insert_statement(line, table_name):
    """
    Parse an INSERT statement and return columns and values.

    Only the form ``INSERT INTO `table` (cols) VALUES (...),(...);`` is
    recognised. Row tuples are split with a quote- and escape-aware scanner,
    so a ``)`` inside a quoted value does not end the row early.

    Args:
        line: The INSERT statement line
        table_name: Expected table name (for validation)

    Returns:
        tuple: (columns_list, values_tuples_list) or (None, None) if parsing fails
    """
    insert_match = re.match(
        rf'^INSERT INTO `{re.escape(table_name)}` \(([^)]+)\) VALUES (.+);$',
        line,
        re.DOTALL,
    )
    if not insert_match:
        return None, None

    columns = [col.strip() for col in insert_match.group(1).split(',')]
    parsed_tuples = [
        parse_value_tuple(value_tuple)
        for value_tuple in _split_value_tuples(insert_match.group(2))
    ]
    return columns, parsed_tuples


def _split_value_tuples(values_str):
    """Split the text after VALUES into its top-level ``(...)`` row strings."""
    tuples = []
    depth = 0
    start = 0
    quote_char = None
    escape_next = False

    for pos, char in enumerate(values_str):
        if escape_next:
            # Character after a backslash is literal, whatever it is.
            escape_next = False
        elif char == '\\':
            escape_next = True
        elif quote_char is not None:
            # Inside a string: only the matching quote is significant.
            if char == quote_char:
                quote_char = None
        elif char in ("'", '"'):
            quote_char = char
        elif char == '(':
            if depth == 0:
                start = pos
            depth += 1
        elif char == ')' and depth > 0:
            depth -= 1
            # Empty "()" groups are ignored, as they were by the old regex.
            if depth == 0 and pos - start > 1:
                tuples.append(values_str[start:pos + 1])

    return tuples


def parse_value_tuple(value_tuple):
    """
    Parse a single value tuple from SQL INSERT statement.
    Handles quotes, escapes, and commas properly.

    Values are returned as raw SQL text (quotes and backslash escapes are
    kept), with surrounding whitespace stripped.

    Args:
        value_tuple: String like "(value1, 'value2', value3)"

    Returns:
        list: List of parsed values
    """
    values = []
    current = []
    quote_char = None
    escape_next = False

    for char in value_tuple[1:-1]:  # Drop the enclosing parentheses
        if escape_next:
            current.append(char)
            escape_next = False
            continue
        if char == '\\':
            current.append(char)
            escape_next = True
            continue
        if quote_char is None:
            if char in ("'", '"'):
                quote_char = char
            elif char == ',':
                # Top-level comma: the current value is complete.
                values.append(''.join(current).strip())
                current = []
                continue
        elif char == quote_char:
            quote_char = None
        current.append(char)

    if current:
        values.append(''.join(current).strip())

    return values


def build_insert_statement(table_name, columns, value_tuples, use_upsert=False):
    """
    Assemble a single-line INSERT statement (terminated by ``;`` and a newline).

    When ``use_upsert`` is set, an ``ON DUPLICATE KEY UPDATE`` clause is
    appended. The columns it updates come from the table's ``update_columns``
    entry in UPSERT_TABLES; without such an entry every column except ``id``
    is updated. If that leaves nothing to update, a plain INSERT is returned.

    Args:
        table_name: Name of the table
        columns: List of column names, as they should appear in the statement
        value_tuples: List of rows, each a list of raw SQL value strings
        use_upsert: If True, use INSERT ... ON DUPLICATE KEY UPDATE

    Returns:
        str: Complete INSERT statement
    """
    column_list = ','.join(columns)
    rows = ','.join('(' + ','.join(row) + ')' for row in value_tuples)
    plain_insert = f"INSERT INTO `{table_name}` ({column_list}) VALUES {rows};\n"

    if not use_upsert:
        return plain_insert

    config = UPSERT_TABLES.get(table_name, {})
    if isinstance(config, dict) and 'update_columns' in config:
        # Only the configured columns are refreshed on a duplicate key.
        wanted = config['update_columns']
        assignments = [
            f"{col}=VALUES({col})" for col in columns if col.strip('`') in wanted
        ]
    else:
        # No explicit configuration: refresh everything but the id column.
        assignments = [
            f"{col}=VALUES({col})" for col in columns if col.strip('`') not in ['id']
        ]

    if not assignments:
        return plain_insert

    update_str = ', '.join(assignments)
    return f"INSERT INTO `{table_name}` ({column_list}) VALUES {rows} ON DUPLICATE KEY UPDATE {update_str};\n"


def remove_columns_from_insert(line, table_name):
    """
    Drop the columns listed in SQL_REMOVAL for ``table_name`` from an INSERT.

    The statement is parsed, the unwanted columns and the values at the same
    positions in every row are discarded, and the statement is rebuilt (as an
    upsert when the table is listed in UPSERT_TABLES). The line is returned
    untouched when the table has no removal rules, the line cannot be parsed,
    or none of the listed columns are present.

    Args:
        line: The INSERT statement line
        table_name: Table name used to look up the removal rules

    Returns:
        str: The rebuilt statement, or the original line
    """
    if not table_name or table_name not in SQL_REMOVAL:
        return line

    columns, value_tuples = parse_insert_statement(line, table_name)
    if columns is None:
        return line

    unwanted = SQL_REMOVAL.get(table_name, set())
    drop_positions = {
        pos for pos, col in enumerate(columns) if col.strip('`') in unwanted
    }
    if not drop_positions:
        return line

    kept_columns = [
        col for pos, col in enumerate(columns) if pos not in drop_positions
    ]
    # Rows shorter than the column list simply have nothing at the missing positions.
    kept_rows = [
        [value for pos, value in enumerate(row) if pos not in drop_positions]
        for row in value_tuples
    ]

    return build_insert_statement(
        table_name, kept_columns, kept_rows, table_name in UPSERT_TABLES
    )


def rename_columns_in_insert(line, table_name):
    """
    Apply the COLUMN_RENAMES entries for ``table_name`` to an INSERT line.

    Every backtick-quoted old column name found anywhere in the line is
    replaced by its backtick-quoted new name. Lines for tables without an
    entry in COLUMN_RENAMES are returned as-is.

    Args:
        line: The INSERT statement line
        table_name: Table name used to look up the rename map

    Returns:
        str: The line with column names renamed
    """
    if table_name and table_name in COLUMN_RENAMES:
        for legacy_name, current_name in COLUMN_RENAMES[table_name].items():
            line = line.replace('`' + legacy_name + '`', '`' + current_name + '`')
    return line


def transform_checks_table(line):
    """
    Append ``object_type`` and ``object_id`` columns to a `checks` INSERT.

    For each row, a non-NULL, non-empty ``sid`` yields ('source', sid);
    otherwise a non-NULL, non-empty ``gid`` yields ('sourcegroup', gid);
    otherwise the row gets ('source', NULL). The existing columns are left
    in place. The line is returned unchanged if it cannot be parsed or does
    not contain both ``sid`` and ``gid`` columns.

    Args:
        line: The INSERT statement line for the `checks` table

    Returns:
        str: The rebuilt statement, or the original line
    """
    columns, value_tuples = parse_insert_statement(line, 'checks')
    if columns is None:
        return line

    bare_names = [col.strip('`') for col in columns]
    if 'sid' not in bare_names or 'gid' not in bare_names:
        return line
    sid_pos = bare_names.index('sid')
    gid_pos = bare_names.index('gid')

    def is_set(value):
        # A usable reference is neither SQL NULL nor an empty (quoted) string.
        return value != 'NULL' and value.strip("'\"") != ''

    extended_rows = []
    for row in value_tuples:
        # Rows too short to hold the column are treated as NULL.
        sid = row[sid_pos] if sid_pos < len(row) else 'NULL'
        gid = row[gid_pos] if gid_pos < len(row) else 'NULL'

        if is_set(sid):
            extra = ["'source'", sid]
        elif is_set(gid):
            extra = ["'sourcegroup'", gid]
        else:
            extra = ["'source'", 'NULL']
        extended_rows.append(row + extra)

    return build_insert_statement(
        'checks',
        columns + ['`object_type`', '`object_id`'],
        extended_rows,
        'checks' in UPSERT_TABLES,
    )


def generate_alerting_associations_insert(alerting_associations):
    """
    Render a complete dump section for the `alerting_associations` table.

    The section consists of the header comment, LOCK TABLES, DISABLE KEYS,
    one multi-row INSERT, ENABLE KEYS and UNLOCK TABLES. No ``id`` column is
    written.

    Args:
        alerting_associations: List of tuples (check_id, association_type, association_id)

    Returns:
        String containing the complete section, or "" when there is nothing to insert
    """
    if not alerting_associations:
        return ""

    rows = ','.join(
        f"({check_id},'{association_type}',{association_id})"
        for check_id, association_type, association_id in alerting_associations
    )

    section = [
        "--",
        "-- Dumping data for table `alerting_associations`",
        "--",
        "",
        "LOCK TABLES `alerting_associations` WRITE;",
        "/*!40000 ALTER TABLE `alerting_associations` DISABLE KEYS */;",
        f"INSERT INTO `alerting_associations` (`check_id`, `association_type`, `association_id`) VALUES {rows};",
        "/*!40000 ALTER TABLE `alerting_associations` ENABLE KEYS */;",
        "UNLOCK TABLES;",
        "",  # Makes the joined text end with a newline
    ]
    return '\n'.join(section)


def add_admin_and_legacy_column_to_users(line, admin_users):
    """
    Append ``is_admin`` and ``legacy_password`` columns to a `users` INSERT.

    ``is_admin`` is 1 when the row's first value (taken to be the user ID,
    with surrounding quotes removed) is in ``admin_users``, else 0.
    ``legacy_password`` is always 1. Lines that cannot be parsed are
    returned unchanged.

    Args:
        line: The INSERT statement line for the `users` table
        admin_users: Set of user ID strings that should be flagged as admin

    Returns:
        str: The rebuilt statement, or the original line
    """
    columns, value_tuples = parse_insert_statement(line, 'users')
    if columns is None:
        return line

    def with_flags(row):
        # The user ID is assumed to be the first value of the row.
        user_id = row[0].strip("'\"") if row else None
        admin_flag = '1' if user_id in admin_users else '0'
        return row + [admin_flag, '1']  # legacy_password = 1 for every migrated user

    flagged_rows = [with_flags(row) for row in value_tuples]
    return build_insert_statement(
        'users',
        columns + ['`is_admin`', '`legacy_password`'],
        flagged_rows,
        'users' in UPSERT_TABLES,
    )


def filter_sql_dump(input_file, output_file, tables_to_remove):
    """
    Filter out specified tables from SQL dump file and rename tables/columns.
    Works with INSERT-only dump format.

    The input is read twice. The first pass collects admin user IDs and
    check/alert links from tables whose sections are not copied to the
    output. The second pass writes the dump with unwanted sections removed,
    tables and columns renamed, and the `checks` / `users` rows transformed.
    A generated `alerting_associations` section is appended when links were
    found. Progress messages are printed to stdout.

    Both files are handled as latin1, which maps every byte to exactly one
    character, so all data that is not deliberately rewritten is copied
    byte-for-byte regardless of the dump's real character set.

    Args:
        input_file: Path to input SQL dump file
        output_file: Path to output filtered SQL dump file
        tables_to_remove: List of table names to remove
    """
    remove_set = set(tables_to_remove)

    admin_users, alerting_associations = _collect_admins_and_associations(input_file)

    # Regex patterns for INSERT-only dump format
    dump_header_pattern = re.compile(r'^-- Dumping data for table `([^`]+)`$')
    unlock_pattern = re.compile(r'^UNLOCK TABLES;$')
    insert_pattern = re.compile(r'^INSERT INTO `([^`]+)`')
    # Statements that name a table but carry no row data, so a whole-line
    # rename is safe for them.
    table_statement_patterns = (
        re.compile(r'^LOCK TABLES `([^`]+)`'),
        re.compile(r'^/\*!40000 ALTER TABLE `([^`]+)` DISABLE KEYS \*/;$'),
        re.compile(r'^/\*!40000 ALTER TABLE `([^`]+)` ENABLE KEYS \*/;$'),
    )

    # Write with the same single-byte codec used for reading. Encoding the
    # latin1-decoded text as UTF-8 would re-encode every non-ASCII byte and
    # corrupt dumps that are already UTF-8.
    with open(input_file, 'r', encoding='latin1') as infile, \
         open(output_file, 'w', encoding='latin1') as outfile:

        skip_section = False

        for line in infile:
            header_match = dump_header_pattern.match(line)
            if header_match:
                current_table = header_match.group(1)
                if current_table in remove_set:
                    skip_section = True
                    print(f"Removing table: {current_table}")
                    continue
                skip_section = False
                if current_table in TABLE_RENAMES:
                    new_table_name = TABLE_RENAMES[current_table]
                    line = rename_table_in_line(line, current_table, new_table_name)
                    print(f"Renaming table: {current_table} -> {new_table_name}")
            elif unlock_pattern.match(line):
                # UNLOCK TABLES ends the section; it is dropped together
                # with a removed section and kept otherwise.
                if not skip_section:
                    outfile.write(line)
                skip_section = False
                continue

            if skip_section:
                continue

            # LOCK TABLES / ALTER TABLE ... KEYS lines
            for pattern in table_statement_patterns:
                statement_match = pattern.match(line)
                if statement_match:
                    old_table = statement_match.group(1)
                    if old_table in TABLE_RENAMES:
                        line = rename_table_in_line(line, old_table, TABLE_RENAMES[old_table])
                    break

            # INSERT statements
            insert_match = insert_pattern.match(line)
            if insert_match:
                old_table = insert_match.group(1)
                # The statement itself says which table it targets.
                target_table = TABLE_RENAMES.get(old_table, old_table)
                if target_table != old_table:
                    # Rewrite only the "INSERT INTO `name`" prefix so row
                    # data that happens to contain the old name is untouched.
                    line = f'INSERT INTO `{target_table}`' + line[insert_match.end():]

                # Transform checks table (must run BEFORE sid/gid are removed)
                if target_table == 'checks':
                    print("Transforming checks table: adding object_type and object_id columns")
                    line = transform_checks_table(line)

                # Remove columns BEFORE renaming: SQL_REMOVAL uses old column names
                line = remove_columns_from_insert(line, target_table)

                # Rename columns in INSERT statement
                line = rename_columns_in_insert(line, target_table)

                # Add is_admin and legacy_password columns to users table
                if target_table == 'users':
                    line = add_admin_and_legacy_column_to_users(line, admin_users)

            outfile.write(line)

        # Generate and write alerting_associations table at the end
        if alerting_associations:
            print(f"Creating alerting_associations table with {len(alerting_associations)} associations")
            outfile.write(generate_alerting_associations_insert(alerting_associations))

        print(f"\nFiltering complete. Output written to: {output_file}")


def _collect_admins_and_associations(input_file):
    """Scan the dump once; return (admin user ID set, alerting association list)."""
    # Mapping of linker tables to association types
    linker_table_mapping = {
        'nagiosna_ChecksCmdLinker': 'command',
        'nagiosna_ChecksHSALinker': 'nagios',
        'nagiosna_ChecksTrapsLinker': 'snmp_receiver',
        'nagiosna_ChecksUsersLinker': 'user',
    }
    users_groups_table = 'nagiosna_users_groups'

    header_pattern = re.compile(r'^-- Dumping data for table `([^`]+)`')
    # Row format (id, user_id, group_id): capture user_id where group_id = 1 (admin)
    admin_row_pattern = re.compile(r'\(\d+,(\d+),1\)')
    # Every linker table stores rows as (check id, linked object id)
    linker_row_pattern = re.compile(r'\((\d+),(\d+)\)')

    admin_users = set()
    alerting_associations = []
    current_table = None

    with open(input_file, 'r', encoding='latin1') as infile:
        for line in infile:
            header_match = header_pattern.match(line)
            if header_match:
                table = header_match.group(1)
                # A header for any other table ends tracking, so rows are
                # never attributed to a stale table when UNLOCK is missing.
                tracked = table == users_groups_table or table in linker_table_mapping
                current_table = table if tracked else None
            elif current_table is None:
                continue
            elif line.startswith('UNLOCK TABLES;'):
                current_table = None
            elif line.startswith('INSERT INTO'):
                if current_table == users_groups_table:
                    admin_users.update(admin_row_pattern.findall(line))
                else:
                    association_type = linker_table_mapping[current_table]
                    alerting_associations.extend(
                        (check_id, association_type, linked_id)
                        for check_id, linked_id in linker_row_pattern.findall(line)
                    )

    return admin_users, alerting_associations


def main():
    """
    Command-line entry point.

    Modes:
        --show-mappings  print the table and column rename maps (no files needed)
        --list-tables    list the tables found in input_file (no output_file needed)
        default          filter input_file into output_file

    Exits through ``parser.error`` (status 2) when a file argument required
    by the selected mode is missing.
    """
    parser = argparse.ArgumentParser(
        description='Filter unwanted tables from MySQL/MariaDB SQL dumps (INSERT-only format) and rename tables/columns'
    )
    # Both positionals are optional at parse time because the informational
    # modes below do not need them; each mode validates what it requires.
    parser.add_argument(
        'input_file',
        nargs='?',
        help='Input SQL dump file'
    )
    parser.add_argument(
        'output_file',
        nargs='?',
        help='Output filtered SQL dump file (not needed with --list-tables or --show-mappings)'
    )
    parser.add_argument(
        '-l', '--list-tables',
        action='store_true',
        help='List all tables in the dump without filtering'
    )
    parser.add_argument(
        '--show-mappings',
        action='store_true',
        help='Show table and column rename mappings'
    )

    args = parser.parse_args()

    if args.show_mappings:
        print("Table Renames:")
        for old, new in TABLE_RENAMES.items():
            print(f"  {old} -> {new}")

        print("\nColumn Renames:")
        for table, columns in COLUMN_RENAMES.items():
            print(f"  {table}:")
            for old_col, new_col in columns.items():
                print(f"    {old_col} -> {new_col}")
        return

    if args.list_tables:
        if args.input_file is None:
            parser.error('input_file is required with --list-tables')

        # List tables found in the dump
        dump_header_pattern = re.compile(r'^-- Dumping data for table `([^`]+)`$')
        with open(args.input_file, 'r', encoding='latin1') as f:
            tables_found = [
                match.group(1)
                for match in map(dump_header_pattern.match, f)
                if match
            ]

        print("Tables found in dump:")
        for table in tables_found:
            removed = " (WILL BE REMOVED)" if table in TABLES_TO_REMOVE else ""
            renamed = f" (WILL BE RENAMED to {TABLE_RENAMES[table]})" if table in TABLE_RENAMES else ""
            print(f"  - {table}{removed}{renamed}")
        return

    if args.input_file is None or args.output_file is None:
        parser.error('input_file and output_file are required')

    # Perform filtering
    filter_sql_dump(args.input_file, args.output_file, TABLES_TO_REMOVE)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()