import argparse
import sys
from datetime import datetime, timedelta, timezone

import requests

def _positive_int(text):
    """argparse ``type=`` callable: parse *text* as an integer >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid positive integer: {text!r}") from None
    if value < 1:
        # 0 or a negative count would put the window start at/after "now".
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def parse_args(argv=None):
    """Parse the plugin's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        argparse.Namespace with ``api_key``, ``check``, ``warning``,
        ``critical`` (floats; -1 means "not set"), ``period``,
        ``time_quantity`` (int >= 1) and ``debug``.

    Exits with status 2 (argparse's convention) on invalid arguments.
    """
    parser = argparse.ArgumentParser(description='Check OpenAI API usage and costs')
    parser.add_argument('-k', '--api-key', required=True, help='OpenAI Admin API key')
    parser.add_argument('-C', '--check', required=True, choices=["COST", "INPUT_TOKENS", "OUTPUT_TOKENS", "CACHED_TOKENS", "NUM_REQUESTS"], help='What to check eg. COST, INPUT_TOKENS')
    # nargs='?' + const=-1: a bare "-w"/"-c" with no value disables the threshold.
    parser.add_argument('-w', '--warning', nargs='?', type=float, const=-1, default=-1, help='Warning threshold')
    parser.add_argument('-c', '--critical', nargs='?', type=float, const=-1, default=-1, help='Critical threshold')
    parser.add_argument('-p', '--period', choices=['day', 'week', 'month'], default='day',
                        help='Time period to check (day, week, month)')
    parser.add_argument('-t', '--time-quantity', type=_positive_int, default=1,
                        help='Number of time periods to look back (e.g., -p week -t 5 for 5 weeks)')
    parser.add_argument('--debug', action='store_true', help='Enable debug output')
    return parser.parse_args(argv)

# Length of one look-back unit; a "month" is approximated as 30 days.
_PERIOD_LENGTHS = {
    'day': timedelta(days=1),
    'week': timedelta(weeks=1),
    'month': timedelta(days=30),
}


def get_time_range(period, time_quantity=1):
    """Return the Unix timestamp (whole seconds) where the look-back window starts.

    Args:
        period: 'day', 'week' or 'month' (30 days).
        time_quantity: Number of *period* units to look back from now.

    Returns:
        int seconds since the epoch, ``time_quantity`` periods before now.

    Raises:
        ValueError: if *period* is not a supported unit.
    """
    try:
        unit = _PERIOD_LENGTHS[period]
    except KeyError:
        raise ValueError(f"unsupported period: {period!r}") from None
    # Use an aware UTC "now": subtracting from a naive local datetime and then
    # calling .timestamp() shifts the result by an hour across a DST change.
    start = datetime.now(timezone.utc) - unit * time_quantity
    return int(start.timestamp())

# Endpoints of the OpenAI organization-level Usage and Costs APIs.
_COMPLETIONS_USAGE_URL = 'https://api.openai.com/v1/organization/usage/completions'
_COSTS_URL = 'https://api.openai.com/v1/organization/costs'

# Maps a --check choice to (field summed from each result, label reported, unit reported).
_USAGE_CHECKS = {
    'INPUT_TOKENS': ('input_tokens', 'input_tokens', 'tokens'),
    'OUTPUT_TOKENS': ('output_tokens', 'output_tokens', 'tokens'),
    'CACHED_TOKENS': ('input_cached_tokens', 'input_cached_tokens', 'tokens'),
    'NUM_REQUESTS': ('num_model_requests', 'model_requests', 'requests'),
}


def _fetch_buckets(url, headers, params, debug=False, timeout=30):
    """Return every time bucket from a paginated OpenAI organization endpoint.

    Follows ``next_page`` cursors while ``has_more`` is true, so look-back
    windows longer than one page (``limit`` buckets) are fully covered.
    On any HTTP, network or JSON-decoding error, prints a Nagios UNKNOWN
    line and exits with status 3.
    """
    params = dict(params)  # copy: the pagination cursor is added per request
    buckets = []
    while True:
        if debug:
            print(f"[DEBUG] Requesting: {url}")
            print(f"[DEBUG] With params: {params}")
        try:
            response = requests.get(url, headers=headers, params=params, timeout=timeout)
            if debug:
                print(f"[DEBUG] Full request URL: {response.url}")
                print(f"[DEBUG] Status code: {response.status_code}")
                print(f"[DEBUG] Response: {response.text}")

            if response.status_code == 401:
                print("UNKNOWN - Invalid API key")
                sys.exit(3)
            if response.status_code == 403:
                print("UNKNOWN - Forbidden: Check if API key has billing permissions")
                sys.exit(3)

            response.raise_for_status()
            payload = response.json()
        except requests.exceptions.RequestException as e:
            print(f"UNKNOWN - Error fetching OpenAI usage data: {str(e)}")
            sys.exit(3)
        except ValueError as e:
            # Older requests versions raise a plain ValueError for a non-JSON body
            # (newer ones raise a RequestException subclass, handled above).
            print(f"UNKNOWN - Invalid JSON in OpenAI response: {e}")
            sys.exit(3)

        buckets.extend(payload.get('data') or [])
        next_page = payload.get('next_page')
        if not payload.get('has_more') or not next_page:
            return buckets
        params['page'] = next_page


def _iter_results(buckets):
    """Yield every result dict from every bucket (missing/null lists are skipped)."""
    for bucket in buckets:
        yield from (bucket.get('results') or [])


def _sum_results(buckets, field):
    """Sum *field* over all results, treating a missing or null value as 0."""
    return sum(result.get(field) or 0 for result in _iter_results(buckets))


def _sum_costs(buckets):
    """Return ``(total amount, currency)`` over all cost results.

    The currency is taken from the first result that reports one and
    defaults to 'usd'. Amounts are summed as returned; no unit conversion.
    """
    total = 0
    currency = None
    for result in _iter_results(buckets):
        amount = result.get('amount') or {}
        total += amount.get('value') or 0
        if currency is None and 'currency' in amount:
            currency = amount['currency']
    return total, currency or 'usd'


def get_data(api_key, check, period='day', time_quantity=1, debug=False, timeout=30):
    """Fetch and total the metric selected by *check* over the look-back window.

    Args:
        api_key: OpenAI Admin API key, sent as a Bearer token.
        check: 'INPUT_TOKENS', 'OUTPUT_TOKENS', 'CACHED_TOKENS' or
            'NUM_REQUESTS' (summed from the completions usage endpoint), or
            'COST' (summed from the costs endpoint).
        period, time_quantity: Look-back window, passed to ``get_time_range``.
        debug: Print request/response details to stdout.
        timeout: Per-request timeout in seconds, so a hung connection cannot
            stall the monitoring check indefinitely.

    Returns:
        ``[label, total, unit]``, e.g. ``["input_tokens", 1234, "tokens"]`` or
        ``["cost", 1.23, "usd"]``.

    On API errors or an unsupported *check*, prints a Nagios UNKNOWN line
    and exits with status 3.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    start_time = get_time_range(period, time_quantity)
    if debug:
        print(f"[DEBUG] Fetching usage data from {datetime.fromtimestamp(start_time)}")

    # Each endpoint caps its page size (31 daily buckets for usage, 180 for
    # costs); _fetch_buckets pages through longer windows.
    if check in _USAGE_CHECKS:
        field, label, unit = _USAGE_CHECKS[check]
        buckets = _fetch_buckets(_COMPLETIONS_USAGE_URL, headers,
                                 {'start_time': start_time, 'limit': 31}, debug, timeout)
        return [label, _sum_results(buckets, field), unit]

    if check == 'COST':
        buckets = _fetch_buckets(_COSTS_URL, headers,
                                 {'start_time': start_time, 'limit': 180}, debug, timeout)
        cost, currency = _sum_costs(buckets)
        return ["cost", cost, currency]

    print(f"UNKNOWN - Unsupported check: {check}")
    sys.exit(3)

def _threshold_status(value, warning, critical):
    """Return ``(status, exit_code)`` for *value* against Nagios thresholds.

    A threshold <= 0 is treated as disabled. Critical is checked before
    warning, and a value equal to a threshold triggers it.
    """
    if critical > 0 and value >= critical:
        return "CRITICAL", 2
    if warning > 0 and value >= warning:
        return "WARNING", 1
    return "OK", 0


def _format_threshold(threshold):
    """Render a threshold for perfdata; disabled thresholds (<= 0) become empty."""
    return "" if threshold <= 0 else f"{threshold}"


def main():
    """Run the check: fetch the metric, compare to thresholds, print, and exit.

    Prints one Nagios status line with perfdata and exits with 0 (OK),
    1 (WARNING) or 2 (CRITICAL).
    """
    args = parse_args()
    usage_data = get_data(args.api_key, args.check, args.period, args.time_quantity, args.debug)
    if args.debug:
        print(f"[DEBUG] Usage data: {usage_data}")
    label, value, unit = usage_data
    if args.debug:
        print(f"[DEBUG] Total {label}: {value} {unit}")

    status, exit_code = _threshold_status(value, args.warning, args.critical)

    plural = "" if args.time_quantity == 1 else "s"
    time_period = f"{args.time_quantity} {args.period}{plural}"

    # Disabled thresholds are left empty so graphers don't plot a -1 line.
    perfdata = f"{label}={value};{_format_threshold(args.warning)};{_format_threshold(args.critical)};;"
    print(f"{status} - {label}: {value} {unit} (Time Period: {time_period}) | {perfdata}")
    sys.exit(exit_code)

# Script entry point: only run the check when executed directly, not on import.
if __name__ == "__main__":
    main()