import csv
import os
import psycopg2

# Function to parse genus accuracies from a CSV file
def parse_genus_accuracies(file_path):
    """Return ``{genus_name: f1_score}`` parsed from a classification-report CSV.

    The header row and the last 3 rows of the file are skipped. Names are
    whitespace-stripped; scores are rounded to 2 decimals.

    Raises:
        ValueError: if a data row is too short or its F1 cell is not numeric.
    """
    return _parse_f1_scores(file_path)


# Function to parse species accuracies from CSV files
def parse_species_accuracies(file_path):
    """Return ``{species_name: f1_score}`` parsed from a classification-report CSV.

    Same layout as the genus report, but underscores in names are replaced
    with spaces (e.g. ``Carcharhinus_leucas`` -> ``Carcharhinus leucas``).

    Raises:
        ValueError: if a data row is too short or its F1 cell is not numeric.
    """
    return _parse_f1_scores(file_path, normalize_name=lambda name: name.replace('_', ' '))


def _parse_f1_scores(file_path, normalize_name=None, f1_column=3, footer_rows=3):
    """Shared parser: map the first-column name of each data row to its F1 score.

    The first non-blank row is the header; the last ``footer_rows`` non-blank
    rows are ignored (presumably the accuracy / macro avg / weighted avg
    summary rows -- TODO confirm against the report generator).

    Args:
        file_path: Path of the CSV report.
        normalize_name: Optional callable applied to each stripped name.
        f1_column: 0-based F1 column used when the header has no ``f1-score``.
        footer_rows: Number of trailing rows to skip.

    Returns:
        dict mapping name -> F1 score rounded to 2 decimals.

    Raises:
        ValueError: if a data row is too short or its F1 cell is not numeric.
    """
    # newline='' is required by the csv module for correct quoted-field parsing.
    with open(file_path, 'r', newline='') as file:
        # Blank lines would crash on row[0] and would also be miscounted as
        # one of the footer rows, so drop them before slicing.
        rows = [row for row in csv.reader(file) if any(cell.strip() for cell in row)]
    if not rows:
        return {}

    # Prefer locating the F1 column by name; fall back to the positional default.
    header = [cell.strip().lower() for cell in rows[0]]
    if 'f1-score' in header:
        f1_column = header.index('f1-score')

    # max(..., 1) keeps the slice empty (never wrapping around) for tiny files.
    data_rows = rows[1:max(len(rows) - footer_rows, 1)]

    accuracies = {}
    for row in data_rows:
        if len(row) <= f1_column:
            raise ValueError(f"{file_path}: row {row!r} has no column {f1_column}")
        name = row[0].strip()
        if normalize_name is not None:
            name = normalize_name(name)
        accuracies[name] = round(float(row[f1_column].strip()), 2)
    return accuracies

# Function to get image counts from the database
def get_image_counts():
    """Return ``{species_value: image_count}`` from the ``training_copy`` table.

    Only rows whose ``genus`` and ``species`` are both non-NULL are counted.
    The key is the ``species`` column value rendered as a string --
    presumably the full binomial (e.g. ``Prionace glauca``); TODO confirm.
    """
    image_counts = {}
    conn = _connect()
    try:
        # The cursor context manager closes the cursor; the connection is
        # closed in ``finally`` so an exception in execute/fetch cannot leak it.
        with conn.cursor() as cur:
            cur.execute("""
        SELECT genus, species, COUNT(*)
        FROM training_copy
        WHERE species IS NOT NULL AND genus IS NOT NULL
        GROUP BY genus, species
    """)
            for _genus, species, count in cur.fetchall():
                key = f"{species}"
                # Grouping is by (genus, species); if one species value shows up
                # under several genera, add the counts rather than letting the
                # last group silently overwrite the earlier ones.
                image_counts[key] = image_counts.get(key, 0) + count
    finally:
        conn.close()
    return image_counts


# Shared connection helper for the database queries above and below
def _connect():
    """Open a psycopg2 connection to the ``pelagic`` database.

    Each setting can be overridden through an environment variable
    (PELAGIC_DB_HOST, PELAGIC_DB_NAME, PELAGIC_DB_USER, PELAGIC_DB_PASSWORD);
    the original hard-coded values remain as fallbacks.
    """
    # SECURITY NOTE(review): a plaintext password is committed in source.
    # Set PELAGIC_DB_PASSWORD (or use ~/.pgpass) and drop this fallback.
    return psycopg2.connect(
        host=os.environ.get("PELAGIC_DB_HOST", "sp2.cs.vt.edu"),
        database=os.environ.get("PELAGIC_DB_NAME", "pelagic"),
        user=os.environ.get("PELAGIC_DB_USER", "francesco"),
        password=os.environ.get("PELAGIC_DB_PASSWORD", "pelagicMaster"),
    )


# Function to get the list of valid shark genera
def get_shark_genera():
    """Return the set of distinct ``genus_name`` values in ``taxonomy3``
    whose ``superorder`` is ``'Selachimorpha'``.

    Names are whitespace-stripped; NULL genus names are ignored.
    """
    conn = _connect()
    try:
        with conn.cursor() as cur:
            cur.execute("""
        SELECT DISTINCT genus_name 
        FROM taxonomy3 
        WHERE superorder = 'Selachimorpha';
    """)
            rows = cur.fetchall()
    finally:
        conn.close()
    # DISTINCT can still yield a NULL genus_name; skip it instead of
    # crashing on None.strip().
    return {name.strip() for (name,) in rows if name is not None}

# Genera treated as having a single species, mapped to that species' full name
single_species_genera = {
    # Key each binomial by its genus (the first word of the name).
    binomial.split()[0]: binomial
    for binomial in (
        'Carcharodon carcharias',
        'Carcharias taurus',
        'Rhincodon typus',
        'Prionace glauca',
        'Galeocerdo cuvier',
        'Triaenodon obesus',
    )
}

# File paths
genus_metrics_file = '/home/spr/SDv4/metrics/classification_report.csv'
species_metrics_dir = '/home/spr/SDv4/species/metrics'

# Genus-level F1 scores (used below as a fallback for single-species genera).
genus_accuracies = parse_genus_accuracies(genus_metrics_file)

# Species-level F1 scores from every "<genus>/<genus>_classification_report.csv"
# under species_metrics_dir. Entries are visited in sorted order so that, if two
# reports list the same species, which one wins is deterministic instead of
# depending on os.listdir() ordering.
species_accuracies = {}
for genus in sorted(os.listdir(species_metrics_dir)):
    species_metrics_file = os.path.join(
        species_metrics_dir, genus, f"{genus}_classification_report.csv"
    )
    if os.path.exists(species_metrics_file):
        species_accuracies.update(parse_species_accuracies(species_metrics_file))

# Single-species genera without a species-level score fall back to the genus
# score (0 if the genus is missing from the genus report).
for genus, species in single_species_genera.items():
    species_accuracies.setdefault(species, genus_accuracies.get(genus, 0))

image_counts = get_image_counts()
shark_genera = get_shark_genera()

# Every database species whose first word is a shark genus appears in the
# output, with accuracy 0 when no classification report scored it.
for genus_species in image_counts:
    words = genus_species.split()
    if not words:
        # Empty/whitespace-only species values (not excluded by IS NOT NULL)
        # would otherwise raise IndexError on words[0].
        continue
    if words[0] in shark_genera:
        species_accuracies.setdefault(genus_species, 0)

# One row per species, alphabetically by name; image count defaults to 0.
csv_data = [
    [species, image_counts.get(species, 0), accuracy]
    for species, accuracy in sorted(species_accuracies.items())
]

csv_file_path = 'sd_table_all.csv'
with open(csv_file_path, 'w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(['Species', 'Images', 'Accuracy'])
    writer.writerows(csv_data)

print(f"CSV file '{csv_file_path}' has been created.")
