Post-HCT Survival Prediction – Sprint 3

SPRINT 3

Predicting post-HCT event-free survival probability (CIBMTR dataset).

  • Problem Type: Regression
  • Final Model: Stacking
  • Best Base Model: XGBoost (R² ≈ 0.18)

Data Overview & Missing Values

EDA

Dataset

  • Train: 28,800 rows, 60 columns.
  • Test: 28,800 rows, 58 columns.
  • Outcome: efs (event), efs_time (months).

Highest Missing Counts

  • cyto_score: 8,068
  • hla_high_res_10: 7,163
  • hla_high_res_8: 5,829
  • hla_match_dqb1_high: 5,199
  • conditioning_intensity: 4,789

Simple rules: fill numeric gaps with a typical value (e.g. the column median), fill missing categories with the most common label, and drop strongly redundant HLA summary fields.
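These rules can be sketched in pandas. The frame below is a hypothetical stand-in for the training data (only the column names echo the dataset); median/mode are one reasonable choice of "typical value" and "common label":

```python
import pandas as pd

# Hypothetical toy frame; real values come from the CIBMTR training data.
df = pd.DataFrame({
    'cyto_score': [1.0, None, 3.0, None],
    'conditioning_intensity': ['MAC', None, 'RIC', 'MAC'],
    'hla_high_res_10': [8, 9, None, 10],  # treated here as a redundant HLA summary
})

# Numeric gaps -> column median
df['cyto_score'] = df['cyto_score'].fillna(df['cyto_score'].median())

# Missing categories -> most frequent label
df['conditioning_intensity'] = df['conditioning_intensity'].fillna(
    df['conditioning_intensity'].mode()[0])

# Drop a strongly redundant summary field
df = df.drop(columns=['hla_high_res_10'])
print(df)
```

The same fill values computed on the training set should be reused on the test set to avoid leakage.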

Target: Survival Probability

TARGET

Idea

  • Use efs_time and efs.
  • Fit Kaplan–Meier curve on the full training set.
  • For each patient: survival probability at their own time.
  • Store this as continuous target survival.

This turns the problem into a regression task on survival probability, not a 0/1 classification.

from lifelines import KaplanMeierFitter

def transform_survival(df, time_col='efs_time', event_col='efs'):
    """Fit Kaplan-Meier on the full training set, then return each
    patient's estimated survival probability at their own time."""
    kmf = KaplanMeierFitter()
    kmf.fit(df[time_col], event_observed=df[event_col])
    return kmf.survival_function_at_times(df[time_col]).values

train_data['survival'] = transform_survival(train_data)
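Conceptually, the Kaplan–Meier estimate is a running product over event times, S(t) = ∏ (1 − dᵢ/nᵢ). A numpy-only sketch on hypothetical toy data (the function name and data are illustrative, not from the project code) shows what the transform returns:

```python
import numpy as np

def km_survival(times, events):
    """Kaplan-Meier survival estimate, evaluated at each subject's own time."""
    order = np.argsort(times)
    t_sorted, e_sorted = times[order], events[order]
    at_risk = len(times)
    surv = 1.0
    step = {}  # distinct time -> S(t)
    i = 0
    while i < len(times):
        t = t_sorted[i]
        d = c = 0  # events / censored at time t
        while i < len(times) and t_sorted[i] == t:
            if e_sorted[i] == 1:
                d += 1
            else:
                c += 1
            i += 1
        if d > 0:
            surv *= 1 - d / at_risk  # only events shrink the curve
        step[t] = surv
        at_risk -= d + c  # censored subjects leave the risk set afterwards
    return np.array([step[t] for t in times])

times = np.array([2.0, 3.0, 3.0, 5.0, 8.0])
events = np.array([1, 1, 0, 1, 0])
print(km_survival(times, events))  # -> [0.8 0.6 0.6 0.3 0.3]
```

Note that censored patients inherit the survival probability at their censoring time, so the target is continuous in [0, 1] rather than a 0/1 label.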

Model & Evaluation (6-Fold CV)

MODELING

What We Do

  • Split processed data into train / validation.
  • Base models (regression): Linear, Random Forest, XGBoost.
  • 6-fold cross-validation on training data.
  • Metric: R² (the scikit-learn default `score` for regressors).
  • Visualize R² per fold (bar plot + mean line).
  • Train a StackingRegressor on all training data.

The better the R², the more variance in survival probability the model can explain.

def model_to_stack():
    kf = KFold(n_splits=6, shuffle=True, random_state=42)
    base_models = [
        ('lr', LinearRegression()),
        ('rf', RandomForestRegressor(n_estimators=100, random_state=42)),
        ('xgb', XGBRegressor(n_estimators=100, learning_rate=0.1, random_state=42)),
    ]

    # Cross-validate each base model and plot R² per fold
    for name, model in base_models:
        scores = cross_val_score(model, train_inputs, target_inputs, cv=kf)
        plt.figure(figsize=(10, 7))
        sns.barplot(x=np.arange(1, len(scores) + 1), y=scores)
        plt.axhline(scores.mean(), linestyle='--', c='r')
        plt.title(f"Performance: {name}")
        plt.xlabel("Fold")
        plt.ylabel("R² score")
        plt.show()

    # Stack the base models with a linear meta-learner, then predict on the test set
    stacking_model = StackingRegressor(estimators=base_models,
                                       final_estimator=LinearRegression(), cv=5)
    stacking_model.fit(train_inputs, target_inputs)
    final_prediction = stacking_model.predict(test_en)
    print(f"Final Prediction: {final_prediction}")
    return final_prediction

Evaluation Workflow

EVALUATION
[Figure: evaluation workflow diagram]

Next Steps

ROADMAP

Better Features

  • Refine risk encodings (e.g. HLA, comorbidities).
  • Group rare categories.
  • Test interaction features.
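Interaction features can be sketched in pandas. The column names below are hypothetical placeholders, not confirmed fields of the dataset; the point is the two common patterns (numeric × numeric, and numeric relative to its category group):

```python
import pandas as pd

# Hypothetical toy rows; real column names come from the CIBMTR data.
df = pd.DataFrame({
    'age_at_hct': [25, 60, 45],
    'comorbidity_score': [0, 3, 1],
    'donor_related': ['related', 'unrelated', 'related'],
})

# Numeric x numeric interaction
df['age_x_comorbidity'] = df['age_at_hct'] * df['comorbidity_score']

# Numeric feature relative to its categorical group mean
df['age_vs_donor_group'] = (df['age_at_hct']
                            - df.groupby('donor_related')['age_at_hct']
                                .transform('mean'))
print(df)
```

Tree models can discover some interactions on their own, so such features mainly help the linear base model and the linear meta-learner in the stack.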

Better Models

  • Hyperparameter tuning for RF/XGBoost.
  • Try survival-specific models (e.g. Cox, DeepSurv).
  • Compare to current stacking baseline.

Better Metrics

  • Add C-index for survival ranking.
  • Check calibration of predicted survival.
  • Consider time-dependent AUC / Brier score.
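The C-index counts comparable pairs (the earlier time must end in an event) and checks whether the higher predicted risk matches the shorter survival. A minimal numpy sketch, assuming risk = 1 − predicted survival probability (library implementations such as `lifelines.utils.concordance_index` would be used in practice):

```python
import numpy as np

def c_index(times, events, risk_scores):
    """Censoring-aware concordance index (naive O(n^2) version)."""
    num, den = 0.0, 0
    n = len(times)
    for i in range(n):
        if events[i] != 1:
            continue  # a pair is comparable only if the earlier time is an event
        for j in range(n):
            if times[j] > times[i]:  # j outlived i
                den += 1
                if risk_scores[i] > risk_scores[j]:
                    num += 1.0        # concordant
                elif risk_scores[i] == risk_scores[j]:
                    num += 0.5        # tie gets half credit
    return num / den

# Hypothetical toy data: risk perfectly orders the survival times
times = np.array([2.0, 4.0, 6.0, 8.0])
events = np.array([1, 1, 0, 1])
risk = np.array([0.9, 0.7, 0.3, 0.1])  # higher risk = shorter predicted survival
print(c_index(times, events, risk))  # -> 1.0
```

Unlike R², the C-index rewards correct ranking of patients rather than exact probability values, which is why it complements a calibration check.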