import os
import re
import sys
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import logging
from tqdm import tqdm
import random
import warnings
import matplotlib.pyplot as plt
import itertools
import pickle

# Silence the FutureWarning torch.load emits about the upcoming weights_only
# default change; checkpoints/features here are trusted local files.
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*weights_only.*")

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s:%(message)s')

# Dataset class
# Dataset class
class VideoFeaturesDataset(Dataset):
    """Per-video feature dataset for the UCF-Crime classifier.

    Each video's flat feature vector is loaded from
    ``<features_dir>/<filename>_features.pt``, standardized with a
    StandardScaler and projected with PCA (both fitted over the whole dataset
    unless pre-fitted instances are passed in, so CV folds can share one fit),
    and optionally augmented with Gaussian noise.
    """

    def __init__(self, features_dir, labels_file, scaler=None, pca=None, augment=False, noise_std=0.01):
        self.features_dir = features_dir
        self.label_map = self.load_labels(labels_file)
        self.data = list(self.label_map.keys())
        self.labels = [self.label_map[filename] for filename in self.data]

        self.scaler = scaler
        self.pca = pca

        # Only fit the normalization when the caller did not supply it.
        if self.scaler is None or self.pca is None:
            self.scaler, self.pca = self.fit_normalization()

        self.augment = augment
        self.noise_std = noise_std

    def load_labels(self, labels_file):
        """Parse ``labels_file`` (lines of "<filename> <value>") into a
        filename -> class-index map.

        The class is the filename's leading ``[A-Za-z_]+`` event-type prefix;
        'Normal_Videos' is pinned to index 0 when present.  Also sets
        ``self.event_type_to_index``.  Raises whatever error occurs while
        reading the file.
        """
        label_map = {}
        event_types = set()
        entries = []  # (filename, event_type) pairs collected in a single pass
        event_type_pattern = re.compile(r'^[A-Za-z_]+')
        try:
            with open(labels_file, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) == 2:
                        filename, _ = parts
                        match = event_type_pattern.match(filename)
                        if match:
                            event_type = match.group(0)
                            event_types.add(event_type)
                            entries.append((filename, event_type))
                    else:
                        logging.warning(f"Invalid line in labels file: {line}")
        except Exception as e:
            logging.error(f"Error reading labels file: {e}")
            raise e

        # Deterministic class ordering, with the "normal" class forced to 0.
        event_types = sorted(event_types)
        if 'Normal_Videos' in event_types:
            event_types.remove('Normal_Videos')
            event_types.insert(0, 'Normal_Videos')
        else:
            logging.warning("Warning: 'Normal_Videos' not found in event types.")

        self.event_type_to_index = {event_type: index for index, event_type in enumerate(event_types)}

        # Assign labels from the entries parsed above.  (The original re-read
        # the file a second time for this; one read is enough and avoids
        # logging each invalid line twice.)
        for filename, event_type in entries:
            label_map[filename] = self.event_type_to_index[event_type]

        return label_map

    def fit_normalization(self):
        """Fit a StandardScaler and a PCA keeping 99% of the variance on every
        loadable feature file.  Raises ValueError if no features were loaded.
        """
        all_features = []
        for filename in tqdm(self.data, desc="Fitting scaler and PCA"):
            # Fix: the f-string had lost its placeholder, so every sample
            # resolved to the same path.  Features are stored per video as
            # "<filename>_features.pt".
            feature_file = f"{filename}_features.pt"
            feature_path = os.path.join(self.features_dir, feature_file)
            if os.path.exists(feature_path):
                try:
                    # NOTE(review): torch.load unpickles arbitrary objects —
                    # only load feature files from trusted sources.
                    features = torch.load(feature_path, map_location='cpu')
                    if isinstance(features, tuple):
                        # Some extractors save (features, extras); keep the tensor.
                        features = features[0]
                    features = features.view(-1).numpy()
                    all_features.append(features)
                except Exception as e:
                    logging.error(f"Error loading features for {filename}: {e}")
            else:
                logging.warning(f"Feature file not found for {filename}, skipping.")

        if not all_features:
            raise ValueError("No feature files were loaded.")

        # np.stack requires every feature vector to have the same length.
        all_features = np.stack(all_features)

        scaler = StandardScaler()
        scaler.fit(all_features)

        pca = PCA(n_components=0.99)
        pca.fit(scaler.transform(all_features))

        return scaler, pca

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return (features, label): a scaled+PCA-projected float32 tensor
        (noise-augmented when ``self.augment``) and its integer class label.
        Re-raises any error hit while loading or transforming the features.
        """
        filename = self.data[idx]
        label = self.labels[idx]
        # Fix: same lost-placeholder bug as in fit_normalization.
        feature_file = f"{filename}_features.pt"
        feature_path = os.path.join(self.features_dir, feature_file)

        try:
            features = torch.load(feature_path, map_location='cpu')
            if isinstance(features, tuple):
                features = features[0]
            features = features.view(-1).numpy()
            features = self.scaler.transform([features])
            features = self.pca.transform(features)
            features = features[0]

            if self.augment:
                features = self.add_noise(features)

            features = torch.tensor(features, dtype=torch.float32)
        except Exception as e:
            logging.error(f"Error loading features for {filename}: {e}")
            raise e

        return features, label

    def add_noise(self, features):
        """Return ``features`` plus zero-mean Gaussian noise (std = noise_std)."""
        noise = np.random.normal(0, self.noise_std, size=features.shape)
        return features + noise

class Classifier(nn.Module):
    """MLP classification head: three hidden blocks (1024, 512, 256), each
    Linear -> BatchNorm1d -> activation -> Dropout(0.5), followed by a final
    linear projection to ``num_classes`` logits.

    Raises ValueError for an activation name other than 'ReLU', 'LeakyReLU'
    or 'ELU'.
    """

    _HIDDEN_WIDTHS = (1024, 512, 256)

    def __init__(self, input_size, num_classes, activation='ReLU'):
        super(Classifier, self).__init__()
        activation_table = {
            'ReLU': nn.ReLU,
            'LeakyReLU': nn.LeakyReLU,
            'ELU': nn.ELU,
        }
        if activation not in activation_table:
            raise ValueError(f"Unsupported activation: {activation}")
        # One shared (stateless) activation instance, as in the original design.
        act = activation_table[activation]()

        layers = []
        width = input_size
        for hidden in self._HIDDEN_WIDTHS:
            layers.extend([
                nn.Linear(width, hidden),
                nn.BatchNorm1d(hidden),
                act,
                nn.Dropout(0.5),
            ])
            width = hidden
        layers.append(nn.Linear(width, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for a batch of flat feature vectors."""
        return self.net(x)

class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017): per-sample cross entropy scaled by
    (1 - p_t)^gamma so well-classified examples contribute less.

    ``num_classes`` and ``device`` are accepted for API compatibility but are
    not used in the computation.
    """

    def __init__(self, gamma=1, reduction='mean', device='cpu', num_classes=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.num_classes = num_classes
        self.device = device
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against class indices
        ``targets``, mean- or sum-reduced per ``self.reduction``."""
        per_sample_ce = nn.functional.cross_entropy(inputs, targets, reduction='none')
        # p_t = probability the model assigned to the true class.
        p_t = torch.exp(-per_sample_ce)
        modulated = (1.0 - p_t) ** self.gamma * per_sample_ce
        return modulated.mean() if self.reduction == 'mean' else modulated.sum()

def train(model, dataloader, criterion, optimizer, device, num_classes):
    """Run one training epoch over ``dataloader``.

    Returns a 5-tuple: (mean per-sample loss, accuracy %, precision %,
    recall %, f1 %), where precision/recall/f1 are support-weighted averages.
    ``num_classes`` is accepted for interface symmetry with ``validate``.
    """
    model.train()
    running_loss = 0.0
    seen_targets = []
    seen_preds = []

    for batch_features, batch_labels in tqdm(dataloader, desc="Training", leave=False):
        batch_features = batch_features.to(device)
        batch_labels = batch_labels.to(device).long()

        optimizer.zero_grad()
        logits = model(batch_features)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Scale by batch size so dividing by the dataset length below gives a
        # true per-sample mean even when the last batch is smaller.
        running_loss += batch_loss.item() * batch_features.size(0)
        seen_targets.extend(batch_labels.cpu().numpy())
        seen_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())

    accuracy = (np.array(seen_preds) == np.array(seen_targets)).mean() * 100
    precision, recall, f1_score, _ = precision_recall_fscore_support(seen_targets, seen_preds, average='weighted', zero_division=0)
    return running_loss / len(dataloader.dataset), accuracy, precision*100, recall*100, f1_score*100

def validate(model, dataloader, criterion, device, num_classes):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns the same 5-tuple as ``train``: (mean per-sample loss, accuracy %,
    precision %, recall %, f1 %) with support-weighted precision/recall/f1.
    """
    model.eval()
    running_loss = 0.0
    seen_targets = []
    seen_preds = []

    with torch.no_grad():
        for batch_features, batch_labels in tqdm(dataloader, desc="Validating", leave=False):
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device).long()

            logits = model(batch_features)
            batch_loss = criterion(logits, batch_labels)

            # Scale by batch size so the division below yields a per-sample mean.
            running_loss += batch_loss.item() * batch_features.size(0)
            seen_targets.extend(batch_labels.cpu().numpy())
            seen_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())

    accuracy = (np.array(seen_preds) == np.array(seen_targets)).mean() * 100
    precision, recall, f1_score, _ = precision_recall_fscore_support(seen_targets, seen_preds, average='weighted', zero_division=0)
    return running_loss / len(dataloader.dataset), accuracy, precision*100, recall*100, f1_score*100

def main():
    """Grid-search hyperparameters with stratified k-fold cross-validation,
    checkpointing after every epoch so an interrupted run can resume, then
    evaluate an ensemble (mean softmax over all trained fold models) on the
    entire dataset.
    """
    # Paths and directories
    features_dir = "/home/gustavo/smart_doorbell/data/ucf-crime/features"
    labels_file = "/home/gustavo/smart_doorbell/data/ucf-crime/labels.txt"
    model_save_dir = "/home/gustavo/smart_doorbell/models/multiclass_classifier"
    checkpoint_dir = "/home/gustavo/smart_doorbell/models/checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Fix: best-model files are saved into model_save_dir below, so it must
    # exist as well (previously only checkpoint_dir was created).
    os.makedirs(model_save_dir, exist_ok=True)

    # Hyperparameters
    batch_size = 32
    num_epochs = 100
    k_folds = 5
    patience = 20  # early-stopping window (epochs without improvement)
    delta = 0.001  # minimum change that counts as an improvement
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyperparameter search space
    learning_rates = [1e-4, 5e-4, 1e-3]
    weight_decays = [1e-4, 5e-4, 1e-3]
    betas_list = [(0.9, 0.999), (0.85, 0.995)]
    activations = ['ReLU', 'LeakyReLU', 'ELU']
    noise_stds = [0.005, 0.01, 0.02]
    seeds = [42, 123, 456]

    hyperparameter_combinations = list(itertools.product(
        seeds, learning_rates, weight_decays, betas_list, activations, noise_stds
    ))

    total_combinations = len(hyperparameter_combinations)
    combination_index = 0

    # Resume from the pickled checkpoint left by a previous interrupted run.
    checkpoint_path = os.path.join(checkpoint_dir, "training_checkpoint.pkl")
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, 'rb') as f:
            checkpoint = pickle.load(f)
        combination_index = checkpoint['combination_index']
        logging.info(f"Resuming training from combination {combination_index + 1}/{total_combinations}")
    else:
        checkpoint = None

    # NOTE(review): on resume, fold models trained by earlier process runs are
    # not reloaded here, so the final ensemble only contains folds trained in
    # the current run.
    ensemble_models = []

    # Hyperparameter tuning loop
    for idx in range(combination_index, total_combinations):
        (seed, lr, weight_decay, betas, activation, noise_std) = hyperparameter_combinations[idx]
        config_id = f"seed{seed}_lr{lr}_wd{weight_decay}_betas{betas[0]}_{betas[1]}_act{activation}_noise{noise_std}"

        # Seed every RNG so each configuration is reproducible.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        logging.info(f"\nStarting training with config {idx+1}/{total_combinations}: {config_id}")

        # NOTE(review): the scaler/PCA are fitted on the full dataset here and
        # shared by every fold below, so validation folds are not fully
        # independent of the normalization fit (mild leakage).
        dataset = VideoFeaturesDataset(features_dir, labels_file, noise_std=noise_std)
        logging.info("Event Type to Index Mapping:")
        for event_type, index in dataset.event_type_to_index.items():
            logging.info(f"{event_type}: {index}")
        num_classes = len(dataset.event_type_to_index)

        skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=seed)
        labels_array = np.array(dataset.labels)

        writer = SummaryWriter(log_dir=os.path.join("runs", "multiclass_classifier", config_id))
        fold = 0
        if checkpoint and checkpoint['config_id'] == config_id:
            fold = checkpoint['fold']
            # Fix: a mid-fold checkpoint (it still carries model state) refers
            # to a fold that has NOT finished.  Rewind one so the [fold:] slice
            # replays that fold and the epoch-resume check below
            # (checkpoint['fold'] == fold) can match; previously the
            # unfinished fold was skipped entirely.
            if checkpoint.get('model_state_dict') is not None:
                fold -= 1
            logging.info(f"Resuming from fold {fold + 1}/{k_folds}")
        else:
            fold = 0

        for train_index, val_index in list(skf.split(np.zeros(len(labels_array)), labels_array))[fold:]:
            fold += 1
            logging.info(f"\nStarting fold {fold}/{k_folds}")

            scaler = dataset.scaler
            pca = dataset.pca

            # Per-fold views reuse the already-fitted scaler/PCA (so no
            # refitting happens); only the training split is augmented.
            train_dataset = VideoFeaturesDataset(features_dir, labels_file, scaler=scaler, pca=pca, augment=True, noise_std=noise_std)
            train_dataset.data = [dataset.data[i] for i in train_index]
            train_dataset.labels = [dataset.labels[i] for i in train_index]

            val_dataset = VideoFeaturesDataset(features_dir, labels_file, scaler=scaler, pca=pca, augment=False)
            val_dataset.data = [dataset.data[i] for i in val_index]
            val_dataset.labels = [dataset.labels[i] for i in val_index]

            # Oversample rare classes: each sample is weighted by the inverse
            # frequency of its class in the training split.
            train_labels = np.array(train_dataset.labels)
            class_sample_counts = np.array([len(np.where(train_labels == t)[0]) for t in np.unique(train_labels)])
            weights = 1. / class_sample_counts
            samples_weights = np.array([weights[t] for t in train_labels])
            samples_weights = torch.from_numpy(samples_weights).double()
            sampler = WeightedRandomSampler(samples_weights, len(samples_weights))

            train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

            # The input size depends on how many PCA components survived the fit.
            sample_features, _ = train_dataset[0]
            input_size = sample_features.size(0)
            model = Classifier(input_size, num_classes, activation=activation).to(device)
            criterion = FocalLoss(gamma=1, device=device, num_classes=num_classes)
            optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas)

            best_val_loss = float('inf')
            best_val_acc = 0.0
            epochs_no_improve = 0
            start_epoch = 0
            if (checkpoint and checkpoint['config_id'] == config_id and checkpoint['fold'] == fold
                    and checkpoint.get('model_state_dict') is not None):
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                best_val_loss = checkpoint['best_val_loss']
                best_val_acc = checkpoint['best_val_acc']
                epochs_no_improve = checkpoint['epochs_no_improve']
                logging.info(f"Resuming from epoch {start_epoch + 1}/{num_epochs}")

            # Fix: bind 'epoch' before the try so a KeyboardInterrupt raised
            # before the first epoch completes cannot hit a NameError in the
            # handler below.
            epoch = start_epoch

            try:
                for epoch in range(start_epoch, num_epochs):
                    train_loss, train_acc, train_precision, train_recall, train_f1 = train(model, train_loader, criterion, optimizer, device, num_classes)
                    val_loss, val_acc, val_precision, val_recall, val_f1 = validate(model, val_loader, criterion, device, num_classes)

                    writer.add_scalars(f'Fold_{fold}/Loss', {'Train': train_loss, 'Validation': val_loss}, epoch)
                    writer.add_scalars(f'Fold_{fold}/Accuracy', {'Train': train_acc, 'Validation': val_acc}, epoch)
                    writer.add_scalars(f'Fold_{fold}/Precision', {'Train': train_precision, 'Validation': val_precision}, epoch)
                    writer.add_scalars(f'Fold_{fold}/Recall', {'Train': train_recall, 'Validation': val_recall}, epoch)
                    writer.add_scalars(f'Fold_{fold}/F1_Score', {'Train': train_f1, 'Validation': val_f1}, epoch)

                    logging.info(f"Config {config_id}, Fold [{fold}], Epoch [{epoch+1}/{num_epochs}]: "
                                 f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

                    # An epoch "improves" when either metric beats its best by
                    # more than delta; the best model per fold is kept on disk.
                    if val_loss < best_val_loss - delta or val_acc > best_val_acc + delta:
                        if val_loss < best_val_loss - delta:
                            best_val_loss = val_loss
                        if val_acc > best_val_acc + delta:
                            best_val_acc = val_acc
                        epochs_no_improve = 0
                        model_save_path = os.path.join(model_save_dir, f"best_model_{config_id}_fold{fold}.pth")
                        torch.save(model.state_dict(), model_save_path)
                    else:
                        epochs_no_improve += 1
                        if epochs_no_improve >= patience:
                            logging.info(f"Early stopping on epoch {epoch+1}")
                            break

                    # Persist progress after every epoch so a crash resumes here.
                    checkpoint = {
                        'combination_index': idx,
                        'config_id': config_id,
                        'fold': fold,
                        'epoch': epoch + 1,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'best_val_loss': best_val_loss,
                        'best_val_acc': best_val_acc,
                        'epochs_no_improve': epochs_no_improve
                    }
                    with open(checkpoint_path, 'wb') as f:
                        pickle.dump(checkpoint, f)

                ensemble_models.append({
                    'model': model,
                    'config_id': config_id,
                    'fold': fold
                })

                # Mark the fold as finished (no model state stored).
                checkpoint = {
                    'combination_index': idx,
                    'config_id': config_id,
                    'fold': fold,
                    'epoch': 0,
                    'model_state_dict': None,
                    'optimizer_state_dict': None,
                    'best_val_loss': float('inf'),
                    'best_val_acc': 0.0,
                    'epochs_no_improve': 0
                }
                with open(checkpoint_path, 'wb') as f:
                    pickle.dump(checkpoint, f)

            except KeyboardInterrupt:
                logging.info("Training interrupted. Saving checkpoint...")
                checkpoint = {
                    'combination_index': idx,
                    'config_id': config_id,
                    'fold': fold,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_loss': best_val_loss,
                    'best_val_acc': best_val_acc,
                    'epochs_no_improve': epochs_no_improve
                }
                with open(checkpoint_path, 'wb') as f:
                    pickle.dump(checkpoint, f)
                sys.exit()

        # Advance to the next configuration on any future resume.
        combination_index = idx + 1
        checkpoint = {
            'combination_index': combination_index,
            'config_id': None,
            'fold': 0,
            'epoch': 0,
            'model_state_dict': None,
            'optimizer_state_dict': None,
            'best_val_loss': float('inf'),
            'best_val_acc': 0.0,
            'epochs_no_improve': 0
        }
        with open(checkpoint_path, 'wb') as f:
            pickle.dump(checkpoint, f)

        writer.close()
        logging.info(f"Completed configuration {combination_index}/{total_combinations}")

    logging.info("Training complete.")

    # Ensemble prediction on entire dataset
    all_labels = []
    all_predictions = []
    dataset = VideoFeaturesDataset(features_dir, labels_file, augment=False)
    event_type_names = [et for et, _ in sorted(dataset.event_type_to_index.items(), key=lambda x: x[1])]

    with torch.no_grad():
        for idx in tqdm(range(len(dataset)), desc="Ensemble Prediction"):
            features, label = dataset[idx]
            features = features.to(device).unsqueeze(0)
            outputs = []
            # Average the softmax probabilities of every trained fold model.
            for entry in ensemble_models:
                model = entry['model']
                model.eval()
                output = model(features)
                probs = torch.softmax(output, dim=1).cpu().numpy()
                outputs.append(probs)
            avg_probs = np.mean(outputs, axis=0)
            predicted_class = np.argmax(avg_probs)
            all_predictions.append(predicted_class)
            all_labels.append(label)

    accuracy = (np.array(all_predictions) == np.array(all_labels)).mean() * 100
    logging.info(f"\nEnsemble Performance on Entire Dataset:")
    logging.info(f"Accuracy: {accuracy:.2f}%")
    logging.info("\nClassification Report:")
    logging.info(classification_report(all_labels, all_predictions, target_names=event_type_names))
    logging.info("\nConfusion Matrix:")
    logging.info(confusion_matrix(all_labels, all_predictions))

if __name__ == "__main__":
    main()  # script entry point: run the full hyperparameter search
