"""
generate_log_query.py
~~~~~~~~~~~~~~~~~~~~~

This module generates Lucene queries from natural-language input using a selectable AI model provider.
"""

import sys
from datetime import datetime
import argparse
import json
# OpenAI version is pinned as "openai<1.23.6"
from openai import OpenAI
import anthropic
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
import requests

def parse_args():
    """Build the command line parser and return the parsed arguments.

    ``--current_fields`` carries the Lucene keywords or mappings defined in
    NLS (for example '@timestamp' or '@version').

    The "self_hosted" provider is a model that customers may run on their own
    infrastructure; at present that looks like a vLLM instance exposing a
    mirrored OpenAI API scheme. ``--provider_address`` and ``--provider_port``
    point at that custom instance.

    Returns:
        The ``argparse.Namespace`` produced from the process arguments.
    """
    cli = argparse.ArgumentParser(description='Generate Lucene queries from natural language.')

    # (flag, keyword options) pairs, registered in this order so the generated
    # --help output lists them the same way.
    option_table = (
        ('--api_key', {'default': '', 'help': 'API key for the AI provider'}),
        ('--provider_address', {'default': None, 'help': 'IP Address of custom LLM container'}),
        ('--provider_port', {'default': "8000", 'help': 'Port of the LLM container'}),
        ('--natural_language_query', {'required': True, 'help': 'Natural language input for query'}),
        ('--current_fields', {'default': '', 'help': 'Fields available for query formulation'}),
        (
            '--model_provider',
            {
                'choices': ['openai', 'mistral', 'anthropic', 'self_hosted'],
                'default': 'openai',
                'help': 'Choice of AI model provider',
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def query_openai(api_key, natural_language_query, system_prompt):
    """Ask OpenAI's ``gpt-4-turbo`` chat model for a query and print the reply.

    Args:
        api_key: Key used to construct the OpenAI client.
        natural_language_query: The user's request, sent as the ``user`` message.
        system_prompt: Instructions sent as the ``system`` message.

    The content of the first returned choice is written to stdout; nothing
    is returned.
    """
    conversation = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': natural_language_query},
    ]
    request_options = {
        'model': 'gpt-4-turbo',
        'messages': conversation,
        'temperature': 0,
        'stream': False,
        'max_tokens': 150,
    }

    completion = OpenAI(api_key=api_key).chat.completions.create(**request_options)

    first_choice = completion.choices[0]
    print(first_choice.message.content)

def query_mistral(api_key, natural_language_query, system_prompt):
    """Ask Mistral's ``mistral-large-latest`` model for a query and print the reply.

    Args:
        api_key: Key used to construct the Mistral client.
        natural_language_query: The user's request, sent as the ``user`` message.
        system_prompt: Instructions sent as the ``system`` message.

    The content of the first returned choice is written to stdout; nothing
    is returned.
    """
    client = MistralClient(api_key=api_key)

    response = client.chat(
        model="mistral-large-latest",
        messages=[
            ChatMessage(role="system", content=system_prompt),
            ChatMessage(role="user", content=natural_language_query),
        ],
        # Same sampling settings as the other providers in this module:
        # deterministic output and a 150-token cap on the generated query.
        temperature=0.0,
        max_tokens=150,
    )

    print(response.choices[0].message.content)

def query_anthropic(api_key, natural_language_query, system_prompt):
    """Ask Anthropic's ``claude-3-5-sonnet-20240620`` model for a query and print it.

    Args:
        api_key: Key used to construct the Anthropic client.
        natural_language_query: The user's request, sent as the only ``user`` message.
        system_prompt: Instructions passed through the ``system`` parameter.

    The text of the first content block in the reply is written to stdout;
    nothing is returned.
    """
    user_turn = {"role": "user", "content": natural_language_query}

    reply = anthropic.Anthropic(api_key=api_key).messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=150,
        temperature=0.0,
        system=system_prompt,
        messages=[user_turn],
    )

    leading_block = reply.content[0]
    print(leading_block.text)

def query_custom_model(provider_address, provider_port, natural_language_query, system_prompt):
    """Ask a self-hosted, OpenAI-compatible endpoint for a query and print the reply.

    The request goes to ``<provider_address>:<provider_port>/v1`` through the
    OpenAI client, using the ``models/llama-3-lucene-8b`` model.

    Args:
        provider_address: Host of the custom LLM instance. May be a full URL
            prefix (``http://10.0.0.5``) or a bare host/IP (``10.0.0.5``); a
            bare value is assumed to be plain HTTP.
        provider_port: Port the instance listens on.
        natural_language_query: The user's request, sent as the ``user`` message.
        system_prompt: Instructions sent as the ``system`` message.

    Raises:
        ValueError: If ``provider_address`` is empty or ``None``.

    The content of the first returned choice is written to stdout; nothing
    is returned.
    """
    if not provider_address or not str(provider_address).strip():
        raise ValueError("provider_address is required for the self-hosted provider")

    # The CLI documents the address as a bare IP, but the HTTP client needs an
    # absolute URL. Default to http:// when no scheme was supplied, and drop a
    # trailing slash so the port is not appended after a path separator.
    address = str(provider_address).strip().rstrip('/')
    if '://' not in address:
        address = f"http://{address}"

    client = OpenAI(
        base_url=f"{address}:{provider_port}/v1",
        # NOTE(review): looks like a placeholder key for an endpoint that does
        # not check credentials - confirm against the deployment.
        api_key="token"
    )

    response = client.chat.completions.create(
        model='models/llama-3-lucene-8b',
        messages=[
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': natural_language_query}
        ],
        temperature=0,
        max_tokens=150,
        # NOTE(review): 128009 is presumably the model's end-of-turn token id,
        # passed as a server-specific extension field - confirm.
        extra_body={"stop_token_ids": [128009]}
    )

    print(response.choices[0].message.content)


def handle_query(model_provider, api_key, natural_language_query, system_prompt, provider_address=None, provider_port=8000):
    """Dispatch the query to the handler for ``model_provider``.

    The handler is chosen from the provider name alone. ``provider_address``
    and ``provider_port`` are used only by the ``self_hosted`` provider and
    are ignored for the hosted ones, so supplying them alongside a hosted
    provider no longer calls a handler with the wrong arguments.

    Args:
        model_provider: One of 'openai', 'mistral', 'anthropic', 'self_hosted'.
        api_key: API key passed to the hosted providers.
        natural_language_query: The user's request.
        system_prompt: Instructions sent to the model.
        provider_address: Address of the self-hosted instance; required when
            ``model_provider`` is 'self_hosted'.
        provider_port: Port of the self-hosted instance.

    Prints an error and exits with status 1 when the provider is unknown or
    when 'self_hosted' is selected without a provider address.
    """
    if model_provider == 'self_hosted':
        # The self-hosted handler takes an address and port instead of an API key.
        if not provider_address:
            print(f"Error: Model {model_provider} requires a provider address.")
            sys.exit(1)
        query_custom_model(provider_address, provider_port, natural_language_query, system_prompt)
        return

    # Hosted providers all share the (api_key, query, prompt) signature.
    hosted_handlers = {
        'openai': query_openai,
        'mistral': query_mistral,
        'anthropic': query_anthropic,
    }

    handler = hosted_handlers.get(model_provider)
    if handler is None:
        print(f"Error: Model {model_provider} not supported.")
        sys.exit(1)

    handler(api_key, natural_language_query, system_prompt)

def _build_system_prompt(fields, current_date_iso8601):
    """Return the system prompt with the available fields and current date filled in."""
    return f"""
    You only output lucene queries. You are embedded in a Nagios Log Server Input field. When you receive user input, you are outputting a lucene query. 
    These are the current available fields to the user: {fields}

    Current date: {current_date_iso8601}

    Note you're not aware of the user's configuration, so you should be careful with AND and ORs, allowing for the possibility that their configuration is slightly different.

    Examples:

    User: I want to see system logs from localhost
    You: type:"syslog" AND host:"127.0.0.1"

    User: Windows events
    You: type: "eventlog" OR type:*event*

    User: General Apache Logs
    You: type:"apache_access"

    User: logs from yesterday (today is oct 20)
    You: @timestamp:[2023-10-19T00:00:00 TO 2023-10-19T23:59:59]

    User: I want to see apache 404 errors
    You: type:"apache_access" AND response:404

    Notably, you should use quotations ("") whenever it is non-numeric. For example:

    User: I want to see GET requests
    You: verb:"GET"

    Do not explain yourself. Your output is going directly into the input field where the query will be executed. 

    User: Windows Updates
    You: ((EventID: 17 AND (message:"Windows" OR message:"update")) OR (EventID: 19 AND (message:"Windows" OR message:"update")) OR (EventID: 20 AND (message:"Windows" OR message:"update")) OR (EventID: 24 AND (message:"Windows" OR message:"update")) OR (EventID: 25 AND (message:"Windows" OR message:"update")) OR (EventID: 35 AND (message:"Windows" OR message:"update")))

    User: Firewall stuff
    You: ((EventID:2005 AND message:"firewall") OR EventID:2005 AND Severity:INFO AND Channel:"Microsoft-Windows-Windows Firewall With Advanced Security/Firewall") OR (EventID:(2006 OR 2033) AND Severity:INFO AND Channel:"Microsoft-Windows-Windows Firewall With Advanced Security/Firewall") OR (EventID:2009 AND Severity:ERROR AND Channel:"Microsoft-Windows-Windows Firewall With Advanced Security/Firewall") OR (EventID:2004 AND Severity:INFO AND Channel:"Microsoft-Windows-Windows Firewall With Advanced Security/Firewall") OR (EventID:4946 AND message:"added") OR (EventID:4947 AND message:"modified") OR (EventID:4950 AND message:"changed") OR (EventID:4954 AND message:"changed") OR (EventID:5025 AND message:"stopped") OR (EventID:5031 AND message:"blocked")

    User: Account Lockouts
    You: (EventID:4740 AND (EventType:AUDIT_SUCCESS OR EventType:AUDIT_FAILURE)) OR (message:"locked out" AND type:eventlog)

    Notes: 

    source:"/var/log/secure" OR message:"sshd" or program:sshd AND message:"error" and source:/var/log/secure/ AND program:sshd is valid but source:/var/log/secure AND message:"sshd" AND program:sshd would not be because of the slash.

    You need to be wary of types. For example, severity:error will throw an error because severity can only be an integer.

    Use "AND" only if the user explicitly specifies multiple criteria that need to be met simultaneously. Otherwise, favor "OR" to cast a wider net.

    DO NOT drag the query out. for example:

    User: Show me network outages - use Windows Event IDs.
    You: (EventID:4000 OR EventID:4001 OR EventID:4002 OR EventID:4003 OR EventID:4004 OR EventID:4005 OR EventID:4006 OR EventID:4007 OR EventID:4010 OR EventID:4011 OR EventID:4012 OR EventID:4013 OR EventID:4014 OR EventID:4015 OR EventID:4016 OR EventID:4017 OR EventID:4018 OR EventID:4019 OR EventID:4020 OR EventID:4021 OR EventID:4022 OR EventID:4023 OR EventID:4024 OR EventID:4025 OR EventID:4026 OR...

    Do not do that. Instead, find the specific relevant queries and make it finite.

    If the user's input is invalid or irrelevant, for example unrelated to IT, only say __INVALID REQUEST__
    """


def main():
    """Parse the command line, build the system prompt and run the query.

    A blank natural language query prints ``*`` and exits with status 1
    without contacting any provider. Otherwise the field list is cut to its
    first 500 characters, embedded in the prompt together with the current
    local time in ISO 8601 form, and the request is handed to
    ``handle_query``.
    """
    cli = parse_args()
    user_request = cli.natural_language_query

    if user_request.strip() == "":
        print("*")
        sys.exit(1)

    # Only the first 500 characters of the field list go into the prompt.
    truncated_fields = cli.current_fields[:500]
    prompt = _build_system_prompt(truncated_fields, datetime.now().isoformat())

    handle_query(
        cli.model_provider,
        cli.api_key,
        user_request,
        prompt,
        cli.provider_address,
        cli.provider_port,
    )

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
