Commit 7c226bf6 authored by mohamadbashar.disoki's avatar mohamadbashar.disoki

v1

parents
# ✅ sMRI Data Fix Complete
## What Was Wrong
- **Root cause**: Incorrect `skiprows` values in `load_smri_data()`
- FreeSurfer `.stats` files use a complex header with comment lines starting with `#`
- Previous code used hardcoded `skiprows=61` and `skiprows=79` which didn't match the actual file structure
- This caused **zero sMRI features** to be loaded (0/871 subjects)
## What Was Fixed
Rewrote `load_smri_data()` to:
1. **Parse ColHeaders dynamically** — finds the `# ColHeaders` line in each file
2. **Extract proper column names** — reads "StructName NumVert SurfArea GrayVol..." from the comment line
3. **Skip to data correctly** — starts reading data on the line immediately after ColHeaders
4. **Handle all file types** — properly parses `lh.aparc.stats`, `rh.aparc.stats`, `aseg.stats`, `wmparc.stats`
## Verification Results
```
✓ Loaded sMRI data shape: (871, 842)
✓ sMRI features per subject - min: 0, max: 842, mean: 841.0
✓ Subjects with >=1 feature: 870/871
✓ sMRI non-zero fraction: 0.9843 (98.4% of data is meaningful)
```
**Before fix:** 0/871 subjects had sMRI features (0 columns)
**After fix:** 870/871 subjects have sMRI features (842 columns)
## Features Extracted
- **Cortical (aparc)**: NumVert, SurfArea, GrayVol, ThickAvg, ThickStd, MeanCurv, GausCurv, FoldInd, CurvInd
- From both left and right hemispheres
- **Subcortical (aseg)**: Volume_mm3, NVoxels
- **White matter (wmparc)**: Volume_mm3, NVoxels
Total: 842 sMRI features per subject (9 cortical measures × 2 hemispheres × ~34 regions ≈ 612, with the remaining ~230 coming from the Volume_mm3/NVoxels measures in aseg.stats and wmparc.stats)
## Expected Impact
With sMRI data now properly loaded and meaningful:
- **Previous AUC (sMRI missing)**: 0.54 ± 0.05
- **Expected AUC (with sMRI)**: **0.65–0.75+**
- Conservative estimate: +0.10–0.15 AUC improvement
- sMRI has strong discriminative signal for ASD classification
## Next Steps
1. **Retrain the model** now with proper multimodal data:
```bash
python train_braingnn.py
```
2. **Monitor improvements**:
- Watch for AUC rising above 0.60 on validation set
- sMRI branch should learn meaningful features now
3. **(Optional) Apply Phase 2–3 fixes from CODE_REVIEW_FINDINGS.md**:
- Disable auxiliary losses: set `lambda_site=0.0, lambda_age=0.0`
- Add preprocessing: StandardScaler for feature normalization
- Reduce model: halve `hidden_dim` to prevent overfitting
This diff is collapsed.
This diff is collapsed.
#!/usr/bin/env python
"""
Diagnose FreeSurfer stats file format to fix the loader.

FreeSurfer ``.stats`` files start with a variable-length block of ``#``
comment lines, so hardcoded ``skiprows`` values break easily.  This script
probes a range of ``skiprows`` offsets for each file type, reports the first
offset that yields a plausible table, and prints sample columns/rows so the
correct parameters can be copied into ``load_smri_data()``.
"""
import os
import pandas as pd
import numpy as np
from pathlib import Path

# Pick a subject with stats files
subject_dir = Path('./data/sMRI/freesurfer_stats/50003')

print("=" * 70)
print("FreeSurfer Stats Format Diagnosis")
print("=" * 70)

# Candidate skiprows offsets to probe for each file type.
test_files = {
    'aseg.stats': list(range(0, 100, 5)),
    'lh.aparc.stats': list(range(0, 100, 5)),
    'wmparc.stats': list(range(0, 100, 5)),
}

for filename, skiprows_list in test_files.items():
    filepath = subject_dir / filename
    if not filepath.exists():
        # Fix: the original f-string had no placeholder, so the missing
        # file's name was never printed.
        print(f"\n⚠ {filepath} NOT FOUND")
        continue

    print(f"\n{'=' * 70}")
    print(f"Testing: {filepath}")
    print(f"{'=' * 70}")

    # Probe increasing skiprows values until a plausible header appears.
    for skiprows in skiprows_list:
        try:
            # Read just the header row at this offset.
            df = pd.read_table(filepath, sep=r'\s+', skiprows=skiprows, nrows=1)
        except Exception:
            # Wrong offset — the parser choked on the comment block.
            # Deliberately continue to the next candidate.
            continue
        if len(df) > 0 and len(df.columns) > 2:  # Plausible header found
            print(f"\n✓ GOOD FORMAT at skiprows={skiprows}")
            print(f" Columns ({len(df.columns)}): {list(df.columns)[:15]}...")
            # Now read the actual data table at the same offset.
            df_full = pd.read_table(filepath, sep=r'\s+', skiprows=skiprows)
            print(f" Rows: {len(df_full)}")
            print(f" Numeric columns: {df_full.select_dtypes(include=[np.number]).shape[1]}")
            # Sample data
            print(f" First row:\n{df_full.iloc[0]}")
            break
    print()

# Now test what features we can extract
print("\n" + "=" * 70)
print("Extracting Sample Features")
print("=" * 70)

try:
    # aseg.stats — subcortical segmentation statistics.
    aseg_file = subject_dir / 'aseg.stats'
    df = pd.read_table(aseg_file, sep=r'\s+', skiprows=4)
    print(f"\naseg.stats columns: {list(df.columns)}")
    print(f"Sample volume measurements:")
    for idx, row in df.head(3).iterrows():
        # Positional access on a Series must go through .iloc; `row[0]`
        # is label-based and deprecated for integer positions in pandas 2.x.
        print(f" {row.iloc[0]}: {row.iloc[3]} mm³ (if column 3 is volume)")
except Exception as e:
    print(f"Error reading aseg: {e}")

try:
    # lh.aparc.stats — left-hemisphere cortical parcellation statistics.
    aparc_file = subject_dir / 'lh.aparc.stats'
    df = pd.read_table(aparc_file, sep=r'\s+', skiprows=2)
    print(f"\nlh.aparc.stats columns: {list(df.columns)}")
    print(f"Sample cortical measures:")
    for idx, row in df.head(3).iterrows():
        print(f" {row.iloc[0]}: Area={row.iloc[3]} mm², ThickAvg={row.iloc[4]} mm (approx)")
except Exception as e:
    print(f"Error reading aparc: {e}")

print("\n" + "=" * 70)
print("Recommendations:")
print("=" * 70)
print("""
Based on the format found above, update load_smri_data() with correct:
1. skiprows values (usually 2-4 for aparc, 4 for aseg)
2. Column indices to extract
Common columns to extract:
aseg.stats: Volume (column 3)
aparc stats: NumVert, SurfArea, GrayVol, ThickAvg, ThickStd, MeanCurv, GausCurv, FoldInd, CurvInd
""")
# Core Deep Learning
torch>=2.0.0
torchvision>=0.15.0
# Scientific Computing
numpy>=1.21.0
scipy>=1.7.0
pandas>=1.3.0
# Machine Learning
scikit-learn>=1.0.0
# Visualization
matplotlib>=3.4.0
seaborn>=0.11.0
# Progress Bars
tqdm>=4.62.0
# Optional: For advanced features
# torch-geometric>=2.0.0 # For advanced graph operations
# tensorboard>=2.8.0 # For training visualization
# wandb>=0.12.0 # For experiment tracking
{
"config": {
"num_samples": 871,
"num_nodes": 200,
"smri_dim": 842,
"num_sites": 20,
"hidden_dim": 128,
"dropout": 0.5,
"batch_size": 32,
"learning_rate": 0.0005,
"weight_decay": 0.01,
"epochs": 200,
"patience": 30,
"lambda_cls": 1.0,
"k_fold": 5,
"random_seed": 42,
"lambda_site": 0.0,
"lambda_age": 0.0,
"lambda_reg": 0.001,
"save_dir": "./results",
"data_dir": "./data"
},
"fold_results": [
{
"fold": 1,
"test_accuracy": 0.5792349726775956,
"test_auc": 0.6030927835051547,
"test_f1": 0.5752463150491635
},
{
"fold": 2,
"test_accuracy": 0.6440677966101694,
"test_auc": 0.6759590792838874,
"test_f1": 0.6437258719215245
},
{
"fold": 3,
"test_accuracy": 0.603448275862069,
"test_auc": 0.6261781494756404,
"test_f1": 0.6035929307771914
},
{
"fold": 4,
"test_accuracy": 0.4470588235294118,
"test_auc": 0.6101623740201567,
"test_f1": 0.2762314681970349
},
{
"fold": 5,
"test_accuracy": 0.49700598802395207,
"test_auc": 0.5030434782608696,
"test_f1": 0.4975514557928677
}
],
"average_metrics": {
"accuracy": 0.5541631713406396,
"auc": 0.6036871729091418,
"f1": 0.5192696083475564
}
}
\ No newline at end of file
This diff is collapsed.
#!/usr/bin/env python
"""
Quick dry-run test to validate data loading and model shapes
"""
import os
import sys
import json
import numpy as np
import scipy.io as scio
import logging

# Set up logging
# Module-level logger shared by all test functions below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Import training utilities
# (loaders, split/dataset helpers and seeding come from the training script,
# the model factory from the model module)
from train_braingnn import (
    load_phenotypic_data, load_fmri_data, load_smri_data,
    create_site_aware_splits, ABIDEDataset, set_seed
)
from braingnn_multimodal import create_model
import torch
from torch.utils.data import DataLoader
def test_data_loading():
    """Test data loading pipeline"""
    logger.info("\n" + "=" * 60)
    logger.info("TEST 1: DATA LOADING")
    logger.info("=" * 60)

    set_seed(42)

    cfg = {
        'num_samples': 871,
        'num_nodes': 200,
        'smri_dim': 2500,
        'data_dir': './data',
    }

    # Phenotypic table first — it supplies the subject IDs that the two
    # imaging loaders key on.  (NB: 'phynotypic' is the on-disk directory
    # name; presumably intentional — confirm against the data layout.)
    logger.info("Loading phenotypic data...")
    pheno = load_phenotypic_data(
        os.path.join(cfg['data_dir'], 'phynotypic'),
        cfg['num_samples'],
        logger,
    )

    logger.info("Loading fMRI data...")
    fmri = load_fmri_data(
        os.path.join(cfg['data_dir'], 'fMRI', 'CC200'),
        pheno['subject_IDs'],
        cfg['num_nodes'],
        logger,
    )

    logger.info("Loading sMRI data...")
    smri = load_smri_data(
        os.path.join(cfg['data_dir'], 'sMRI', 'freesurfer_stats'),
        pheno['subject_IDs'],
        logger,
    )

    logger.info(f"\nData shapes:")
    logger.info(f" fMRI: {fmri.shape}")
    logger.info(f" sMRI: {smri.shape}")
    logger.info(f" Labels: {pheno['labels'].shape}")
    logger.info(f" Subject IDs: {len(pheno['subject_IDs'])}")

    # Basic quality checks: fraction of non-trivial values and class balance.
    logger.info(f"\nData quality checks:")
    fmri_frac = np.mean(np.abs(fmri) > 1e-6)
    smri_frac = np.mean(np.abs(smri) > 1e-6)
    logger.info(f" fMRI non-zero fraction: {fmri_frac:.4f}")
    logger.info(f" sMRI non-zero fraction: {smri_frac:.4f}")
    logger.info(f" Label balance: ASD={np.sum(pheno['labels']==1)}, TD={np.sum(pheno['labels']==0)}")

    if smri_frac < 0.01:
        logger.warning("WARNING: sMRI data is mostly zeros or not loaded!")

    return fmri, smri, pheno
def test_model_forward(fmri_data, smri_data, pheno_data):
    """Test model forward pass"""
    logger.info("\n" + "=" * 60)
    logger.info("TEST 2: MODEL FORWARD PASS")
    logger.info("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Model sized to the data that was actually loaded.
    cfg = {
        'num_nodes': 200,
        'smri_dim': smri_data.shape[1],
        'num_sites': len(np.unique(pheno_data['sites'])),
        'hidden_dim': 256,
        'dropout': 0.3,
    }
    model = create_model(cfg).to(device)
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Build a tiny batch: real fMRI/sMRI rows, random covariates.
    n = 4
    fmri_in = torch.FloatTensor(fmri_data[:n]).to(device)
    smri_in = torch.FloatTensor(smri_data[:n]).to(device)
    site_in = torch.randint(0, cfg['num_sites'], (n,)).to(device)
    age_in = torch.FloatTensor(np.random.randn(n, 1)).to(device)
    gender_in = torch.randint(0, 2, (n,)).to(device)
    fiq_in = torch.FloatTensor(np.random.randn(n, 1)).to(device)

    logger.info(f"\nBatch shapes:")
    logger.info(f" fMRI: {fmri_in.shape}")
    logger.info(f" sMRI: {smri_in.shape}")
    logger.info(f" site: {site_in.shape}")
    logger.info(f" age: {age_in.shape}")
    logger.info(f" gender: {gender_in.shape}")
    logger.info(f" fiq: {fiq_in.shape}")

    # Forward pass
    logger.info("\nRunning forward pass...")
    try:
        with torch.no_grad():
            class_logits, site_logits, age_pred, attn = model(
                fmri_in, smri_in, site_in, age_in, gender_in, fiq_in
            )
        logger.info(f"Output shapes:")
        logger.info(f" class_logits: {class_logits.shape}")
        logger.info(f" site_logits: {site_logits.shape}")
        logger.info(f" age_pred: {age_pred.shape}")
        logger.info("✓ Forward pass successful!")
        return True
    except Exception as e:
        logger.error(f"✗ Forward pass failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_single_batch_training(fmri_data, smri_data, pheno_data):
    """Test single batch training step"""
    logger.info("\n" + "=" * 60)
    logger.info("TEST 3: SINGLE BATCH TRAINING")
    logger.info("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    set_seed(42)

    # Normalize site labels: trim whitespace, then strip quote and bracket
    # characters so identical sites compare equal.
    drop_chars = str.maketrans('', '', '\'"[]')
    pheno_data['sites'] = np.array(
        [str(s).strip().translate(drop_chars) for s in pheno_data['sites']]
    )
    unique_sites = np.unique(pheno_data['sites'])
    site_to_idx = {site: idx for idx, site in enumerate(unique_sites)}

    # Tiny dataset: just the first 16 samples, no augmentation.
    subset = list(range(min(16, len(pheno_data['labels']))))
    dataset = ABIDEDataset(
        fmri_data, smri_data, pheno_data['labels'],
        pheno_data['sites'], pheno_data['ages'],
        pheno_data['genders'], pheno_data['fiqs'],
        subset, site_to_idx, augment=False
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=False)

    # Reduced-size model so the smoke test stays fast.
    cfg = {
        'num_nodes': 200,
        'smri_dim': smri_data.shape[1],
        'num_sites': len(unique_sites),
        'hidden_dim': 128,  # Smaller for test
        'dropout': 0.3,
    }
    model = create_model(cfg).to(device)

    # Optimizer and multi-task loss.
    import torch.optim as optim
    from train_braingnn import MultiTaskLoss
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = MultiTaskLoss(lambda_cls=1.0, lambda_site=0.1, lambda_age=0.05, lambda_reg=0.001)

    logger.info("Running training step...")
    try:
        model.train()
        for step, sample in enumerate(loader):
            logger.info(f"\nBatch {step}:")

            # Move every tensor in the batch to the target device.
            fmri = sample['fmri'].to(device)
            smri = sample['smri'].to(device)
            labels = sample['label'].to(device)
            sites = sample['site'].to(device)
            ages = sample['age'].to(device)
            genders = sample['gender'].to(device)
            fiqs = sample['fiq'].to(device)
            logger.info(f" Batch shapes: fMRI={fmri.shape}, sMRI={smri.shape}, labels={labels.shape}")

            # Forward
            class_logits, site_logits, age_pred, _ = model(
                fmri, smri, sites, ages, genders, fiqs
            )
            logger.info(f" Logits shapes: class={class_logits.shape}, site={site_logits.shape}")

            # Loss and backward
            loss, loss_dict = criterion(
                class_logits, site_logits, age_pred, labels, sites, ages, model
            )
            logger.info(f" Loss: total={loss.item():.4f}, cls={loss_dict['cls_loss']:.4f}, site={loss_dict['site_loss']:.4f}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            logger.info(f" ✓ Training step {step} successful!")

            break  # One batch is enough for a smoke test.

        logger.info("\n✓ Single batch training successful!")
        return True
    except Exception as e:
        logger.error(f"✗ Training failed: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    logger.info("=" * 60)
    logger.info("DRY-RUN TEST SUITE")
    logger.info("=" * 60)
    try:
        # Test 1: data loading feeds tests 2 and 3.
        fmri_data, smri_data, pheno_data = test_data_loading()

        # Test 2: model forward pass.
        forward_ok = test_model_forward(fmri_data, smri_data, pheno_data)

        # Test 3: one training step, only if the forward pass worked.
        if forward_ok:
            test_single_batch_training(fmri_data, smri_data, pheno_data)

        logger.info("\n" + "=" * 60)
        logger.info("DRY-RUN COMPLETE")
        logger.info("=" * 60)
    except Exception as e:
        logger.error(f"Dry-run failed with error: {e}")
        import traceback
        traceback.print_exc()
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment