Source code for submarit.evaluation.visualization

"""Visualization utilities for evaluation results."""

from typing import List, Optional, Union, Dict, Any
import numpy as np
from numpy.typing import NDArray
import warnings

from submarit.evaluation.cluster_evaluator import ClusterEvaluationResult
from submarit.evaluation.gap_statistic import GAPStatisticResult
from submarit.evaluation.entropy_evaluator import EntropyClusterResult


[docs] class EvaluationVisualizer: """Visualization utilities for SUBMARIT evaluation results.""" def __init__(self): """Initialize the visualizer.""" self._check_matplotlib() def _check_matplotlib(self) -> bool: """Check if matplotlib is available.""" try: import matplotlib.pyplot as plt self.plt = plt return True except ImportError: warnings.warn("matplotlib not available for visualization") self.plt = None return False
[docs] def plot_cluster_comparison( self, results: List[ClusterEvaluationResult], labels: Optional[List[str]] = None, metrics: Optional[List[str]] = None, save_path: Optional[str] = None ) -> None: """Compare multiple clustering results. Args: results: List of evaluation results to compare labels: Labels for each result metrics: Metrics to plot (default: all) save_path: Optional path to save the plot """ if not self.plt: return if metrics is None: metrics = ["diff", "z_value", "log_lh", "scaled_diff"] if labels is None: labels = [f"Result {i+1}" for i in range(len(results))] n_metrics = len(metrics) fig, axes = self.plt.subplots(1, n_metrics, figsize=(4*n_metrics, 5)) if n_metrics == 1: axes = [axes] for idx, metric in enumerate(metrics): ax = axes[idx] values = [getattr(r, metric) for r in results] bars = ax.bar(range(len(results)), values) ax.set_xlabel("Clustering") ax.set_ylabel(metric.replace("_", " ").title()) ax.set_title(f"{metric.replace('_', ' ').title()} Comparison") ax.set_xticks(range(len(results))) ax.set_xticklabels(labels, rotation=45, ha='right') # Color best result best_idx = np.argmax(values) if metric != "scaled_diff" else np.argmin(np.abs(values)) bars[best_idx].set_color('green') self.plt.tight_layout() if save_path: self.plt.savefig(save_path, dpi=300, bbox_inches='tight') else: self.plt.show() self.plt.close()
[docs] def plot_cluster_sizes( self, result: Union[ClusterEvaluationResult, EntropyClusterResult], save_path: Optional[str] = None ) -> None: """Plot cluster size distribution. Args: result: Clustering result save_path: Optional path to save the plot """ if not self.plt: return fig, (ax1, ax2) = self.plt.subplots(1, 2, figsize=(12, 5)) # Bar plot of cluster sizes cluster_labels = [f"Cluster {i+1}" for i in range(result.n_clusters)] ax1.bar(cluster_labels, result.counts) ax1.set_xlabel("Cluster") ax1.set_ylabel("Number of Items") ax1.set_title("Cluster Sizes") ax1.axhline(y=np.mean(result.counts), color='r', linestyle='--', label=f'Mean: {np.mean(result.counts):.1f}') ax1.legend() # Pie chart ax2.pie(result.counts, labels=cluster_labels, autopct='%1.1f%%') ax2.set_title("Cluster Size Distribution") self.plt.tight_layout() if save_path: self.plt.savefig(save_path, dpi=300, bbox_inches='tight') else: self.plt.show() self.plt.close()
[docs] def plot_probability_comparison( self, result: ClusterEvaluationResult, cluster_idx: Optional[int] = None, save_path: Optional[str] = None ) -> None: """Plot P vs PHat comparison. Args: result: Evaluation result cluster_idx: Optional specific cluster to highlight save_path: Optional path to save the plot """ if not self.plt: return fig, (ax1, ax2) = self.plt.subplots(1, 2, figsize=(12, 5)) # Scatter plot of P vs PHat ax1.scatter(result.p, result.p_hat, alpha=0.6) ax1.plot([0, 1], [0, 1], 'r--', label='y=x') ax1.set_xlabel('P (Theoretical)') ax1.set_ylabel('PHat (Estimated)') ax1.set_title('P vs PHat Comparison') ax1.legend() ax1.grid(True, alpha=0.3) # Difference histogram differences = result.p_hat - result.p ax2.hist(differences, bins=30, edgecolor='black', alpha=0.7) ax2.axvline(x=0, color='r', linestyle='--', label='Zero difference') ax2.set_xlabel('PHat - P') ax2.set_ylabel('Frequency') ax2.set_title('Distribution of Differences') ax2.legend() self.plt.tight_layout() if save_path: self.plt.savefig(save_path, dpi=300, bbox_inches='tight') else: self.plt.show() self.plt.close()
[docs] def plot_entropy_evolution( self, entropy_values: List[float], save_path: Optional[str] = None ) -> None: """Plot entropy values over iterations. Args: entropy_values: List of entropy values per iteration save_path: Optional path to save the plot """ if not self.plt: return fig, ax = self.plt.subplots(figsize=(10, 6)) iterations = range(len(entropy_values)) ax.plot(iterations, entropy_values, 'b-o', linewidth=2, markersize=6) ax.set_xlabel('Iteration') ax.set_ylabel('Entropy') ax.set_title('Entropy Evolution During Clustering') ax.grid(True, alpha=0.3) # Mark convergence if len(entropy_values) > 1: final_value = entropy_values[-1] ax.axhline(y=final_value, color='r', linestyle='--', label=f'Final: {final_value:.4f}') ax.legend() self.plt.tight_layout() if save_path: self.plt.savefig(save_path, dpi=300, bbox_inches='tight') else: self.plt.show() self.plt.close()
[docs] def create_evaluation_report( self, results: Dict[str, Any], save_path: str ) -> None: """Create a comprehensive evaluation report. Args: results: Dictionary containing various evaluation results save_path: Path to save the report (PDF or image) """ if not self.plt: return # Create figure with subplots fig = self.plt.figure(figsize=(16, 12)) # Title fig.suptitle('SUBMARIT Clustering Evaluation Report', fontsize=16, fontweight='bold') # Create grid for subplots gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3) # Add various plots based on available results plot_idx = 0 if 'cluster_eval' in results: eval_result = results['cluster_eval'] # Cluster sizes ax1 = fig.add_subplot(gs[0, 0]) ax1.bar(range(len(eval_result.counts)), eval_result.counts) ax1.set_title('Cluster Sizes') ax1.set_xlabel('Cluster') ax1.set_ylabel('Count') # P vs PHat ax2 = fig.add_subplot(gs[0, 1]) ax2.scatter(eval_result.p, eval_result.p_hat, alpha=0.5) ax2.plot([0, max(eval_result.p)], [0, max(eval_result.p)], 'r--') ax2.set_title('P vs PHat') ax2.set_xlabel('P') ax2.set_ylabel('PHat') # Metrics summary ax3 = fig.add_subplot(gs[0, 2]) metrics = { 'Diff': eval_result.diff, 'Z-value': eval_result.z_value, 'Log-LH': eval_result.log_lh, 'Scaled Diff': eval_result.scaled_diff } ax3.bar(metrics.keys(), metrics.values()) ax3.set_title('Evaluation Metrics') ax3.tick_params(axis='x', rotation=45) if 'gap_result' in results: gap_result = results['gap_result'] # GAP statistic plot ax4 = fig.add_subplot(gs[1, :2]) ax4.plot(gap_result.k_values, gap_result.gap_values, 'g-^', linewidth=2) ax4.axvline(x=gap_result.best_k, color='r', linestyle='--', label=f'Best k={gap_result.best_k}') ax4.set_title('GAP Statistic') ax4.set_xlabel('Number of Clusters (k)') ax4.set_ylabel('GAP Value') ax4.legend() ax4.grid(True, alpha=0.3) if 'entropy_result' in results: entropy_result = results['entropy_result'] # Entropy metrics ax5 = fig.add_subplot(gs[1, 2]) entropy_metrics = { 'Total': entropy_result.entropy, 'Normalized': entropy_result.entropy_norm, 'Size-scaled': entropy_result.entropy_norm2 } ax5.bar(entropy_metrics.keys(), entropy_metrics.values()) ax5.set_title('Entropy Metrics') ax5.set_ylabel('Entropy') # Add text summary ax_text = fig.add_subplot(gs[2, :]) ax_text.axis('off') summary_text = self._generate_text_summary(results) ax_text.text(0.05, 0.95, summary_text, transform=ax_text.transAxes, fontsize=10, verticalalignment='top', fontfamily='monospace') self.plt.tight_layout() self.plt.savefig(save_path, dpi=300, bbox_inches='tight') self.plt.close()
def _generate_text_summary(self, results: Dict[str, Any]) -> str: """Generate text summary for the report.""" lines = ["SUMMARY", "=" * 50] if 'cluster_eval' in results: eval_result = results['cluster_eval'] lines.extend([ f"Number of clusters: {eval_result.n_clusters}", f"Number of items: {eval_result.n_items}", f"Z-value: {eval_result.z_value:.4f}", f"Log-likelihood: {eval_result.log_lh:.4f}", f"Total difference (PHat-P): {eval_result.diff:.6f}", "" ]) if 'gap_result' in results: gap_result = results['gap_result'] lines.extend([ f"GAP Statistic Analysis:", f" Optimal k: {gap_result.best_k}", f" Criterion: {gap_result.criterion_name}", f" K range tested: [{gap_result.k_values[0]}, {gap_result.k_values[-1]}]", "" ]) if 'entropy_result' in results: entropy_result = results['entropy_result'] lines.extend([ f"Entropy Clustering:", f" Iterations: {entropy_result.n_iter}", f" Total entropy: {entropy_result.entropy:.6f}", f" Max objective: {entropy_result.max_obj:.6f}", "" ]) return "\n".join(lines)