Machine Unlearning: The Right to be Forgotten
Abstract
"Compliance Gridlock" occurs when a single user invokes their Right to Erasure (GDPR Article 17), forcing an engineering team to decide between two fatal options: ignore the law (regulatory risk) or retrain a massive model from scratch to remove one row of data (financial ruin). In the era of billion-parameter models, full retraining for every deletion request is economically impossible. This post introduces Machine Unlearning, specifically the SISA (Sharded, Isolated, Sliced, Aggregated) architecture. By restructuring how models consume data, we turn an O(N) retraining problem into an O(N/k) problem, enabling exact unlearning with fractional compute costs while maintaining statistical guarantees of deletion.
1. Why This Topic Matters
Data deletion is no longer just a database operation; it is a machine learning operation.
When a user deletes their account, removing their rows from PostgreSQL is easy. But if their data was used to train a neural network, their "digital shadow" persists in the weights.
- The Legal Reality: Regulators (EU GDPR, California DELETE Act) increasingly view model weights as "personal data" if the model can memorize specific training examples.
- The Economic Reality: If retraining your core model costs $50,000 and takes 3 days, and you receive 10 deletion requests a week, you are mathematically bankrupt.
- The Security Reality: Membership Inference Attacks (MIAs) can prove whether a specific user's data was used, making "pretend deletion" (ignoring the weights) auditable and punishable.
2. Core Concepts & Mental Models
The Unlearning Spectrum
- Exact Unlearning: The resulting model is mathematically identical to one trained without the data. (Gold Standard).
- Approximate Unlearning: Modifying weights (e.g., gradient ascent on the deleted point) to "scrub" the influence. (Faster, but no theoretical guarantee the data is truly gone).
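To make the approximate approach concrete, here is a minimal, illustrative sketch on a toy NumPy logistic-regression model (everything here is a stand-in, not a production recipe): a few gradient ascent steps on the deleted point push its loss back up, "scrubbing" its influence — but notice there is no proof about what the remaining weights still encode.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def point_loss(w, x, t):
    # Per-sample binary cross-entropy
    p = sigmoid(x @ w)
    return -(t * np.log(p + 1e-12) + (1 - t) * np.log(1 - p + 1e-12))

# Toy logistic-regression model trained on 200 points
X = rng.normal(size=(200, 5))
w_true = rng.normal(size=5)
y = (X @ w_true > 0).astype(float)

w, lr = np.zeros(5), 0.3
for _ in range(50):                         # gradient DESCENT on the full set
    w -= lr * X.T @ (sigmoid(X @ w) - y) / len(X)

x_del, y_del = X[0], y[0]                   # the sample to "scrub"
loss_before = point_loss(w, x_del, y_del)

# Approximate unlearning: a few gradient ASCENT steps on the deleted
# point only, pushing its loss UP so the model no longer fits it
for _ in range(10):
    w += lr * x_del * (sigmoid(x_del @ w) - y_del)

loss_after = point_loss(w, x_del, y_del)
print(f"loss on deleted point: before={loss_before:.4f}, after={loss_after:.4f}")
```

The loss on the deleted point rises, which is the whole trick — and the whole weakness: nothing bounds how far the influence of that point has actually been removed.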
The SISA Architecture
To achieve Exact Unlearning without full retraining, we divide and conquer.
- Sharding: Split the training data D into k disjoint shards.
- Isolation: Train a separate sub-model on each shard.
- Aggregation: At inference time, the sub-models vote (majority vote or average logits) to produce the final prediction.
The "Undo" Button
When a user (located in Shard 2) requests deletion:
- Delete their data from Shard 2.
- Retrain only the Shard 2 model.
- All other shard models remain untouched.
Result: We save roughly (k-1)/k of the compute cost. For 3 shards, we save 66%. For 10 shards, 90%.
3. Theoretical Foundations
The Performance vs. Efficiency Trade-off
SISA is an ensemble method.
- Benefit: Ensembles often generalize well.
- Cost: Each sub-model sees only 1/k of the data. If k is too large, individual models become weak learners, and accuracy drops.
- Optimization: Slicing. Within each shard, we train incrementally and save checkpoints. If a data point is in the last "slice," we only need to resume training from the second-to-last checkpoint, further reducing cost.
4. Production-Grade Implementation
The "Unlearning Service" Wrapper
We wrap the SISA ensemble in a standard inference class.
- Mapper: A persistent key-value store (Redis) maps UserID -> ShardID.
- Orchestrator: When DELETE /user/{id} is called:
  - Look up the ShardID.
  - Trigger an async retraining job for that specific shard.
  - Hot-swap the new shard model weights upon completion.
- Inference: Parallelize calls to all shard models and aggregate.
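A minimal sketch of this wrapper, using a plain dict as a stand-in for the Redis mapper and a synchronous retrain in place of the async job (the `handle_delete` name and the modulo-3 user-to-shard assignment are hypothetical):

```python
import numpy as np
from sklearn.svm import SVC
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=300, n_features=8, random_state=1)

# Mapper: dict stand-in for the Redis UserID -> ShardID store
user_to_shard = {uid: uid % 3 for uid in range(len(X))}
shard_members = {s: [u for u, sh in user_to_shard.items() if sh == s]
                 for s in range(3)}
shard_models = {s: SVC(probability=True).fit(X[m], y[m])
                for s, m in shard_members.items()}

def handle_delete(user_id):
    """Orchestrator: look up the shard, retrain it, hot-swap the model."""
    shard = user_to_shard.pop(user_id)          # 1. lookup + drop mapping
    shard_members[shard].remove(user_id)
    members = shard_members[shard]
    # 2-3. retrain only this shard, then swap in the new model
    shard_models[shard] = SVC(probability=True).fit(X[members], y[members])
    return shard

def predict(x):
    # Inference: fan out to every shard model, average the probabilities
    probs = np.mean([m.predict_proba(x) for m in shard_models.values()], axis=0)
    return np.argmax(probs, axis=1)

retrained = handle_delete(user_id=42)
print(f"retrained only shard {retrained}; prediction: {predict(X[:1])}")
```

In production the retrain step would be queued to a worker pool, with the old model continuing to serve until the swap completes.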
5. Hands-On Project / Exercise
Goal: Implement a 3-Shard SISA framework for a classifier.
Constraint: Demonstrate that unlearning one sample triggers only 33% of the total training workload.
Setup
We use scikit-learn for simplicity (Support Vector Machines), but the logic applies identically to Deep Learning (PyTorch/TensorFlow).
# pip install scikit-learn numpy
import numpy as np
import time
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.datasets import make_classification

# --- 1. SISA Architecture Definition ---
class SISAEnsemble:
    def __init__(self, num_shards=3):
        self.num_shards = num_shards
        self.shards = []  # List of {model, data_indices}
        self.X_train = None
        self.y_train = None

    def fit(self, X, y):
        self.X_train = X
        self.y_train = y
        n_samples = len(X)

        # A. Sharding: Randomly partition data indices
        indices = np.arange(n_samples)
        np.random.shuffle(indices)
        shard_size = n_samples // self.num_shards

        print("--- TRAINING START (Full Build) ---")
        start_global = time.time()

        for i in range(self.num_shards):
            start = i * shard_size
            end = (i + 1) * shard_size if i < self.num_shards - 1 else n_samples
            shard_idx = indices[start:end]

            # B. Isolation: Train separate model per shard
            print(f"Training Shard {i} with {len(shard_idx)} samples...")
            model = SVC(probability=True, random_state=42)
            model.fit(X[shard_idx], y[shard_idx])

            self.shards.append({
                "model": model,
                "indices": list(shard_idx)  # Mutable list for deletion
            })

        print(f"Total Training Time: {time.time() - start_global:.4f}s")

    def predict(self, X):
        # C. Aggregation: Soft Voting (Average Probabilities)
        all_probs = [shard["model"].predict_proba(X) for shard in self.shards]
        avg_probs = np.mean(all_probs, axis=0)
        return np.argmax(avg_probs, axis=1)

    def unlearn(self, data_index_to_remove):
        """
        The core mechanism: Retrain ONLY the affected shard.
        """
        print(f"\n--- UNLEARNING REQUEST: Index {data_index_to_remove} ---")
        start_unlearn = time.time()

        # 1. Find the shard containing the data point
        target_shard_id = next(
            (i for i, shard in enumerate(self.shards)
             if data_index_to_remove in shard["indices"]),
            -1
        )
        if target_shard_id == -1:
            print("Data point not found (already deleted?).")
            return

        # 2. Remove index from shard metadata
        self.shards[target_shard_id]["indices"].remove(data_index_to_remove)
        current_indices = self.shards[target_shard_id]["indices"]

        # 3. Retrain ONLY that shard
        print(f"Re-training ONLY Shard {target_shard_id} (Size: {len(current_indices)})...")
        new_model = SVC(probability=True, random_state=42)
        new_model.fit(self.X_train[current_indices], self.y_train[current_indices])

        # 4. Hot Swap
        self.shards[target_shard_id]["model"] = new_model

        print(f"Unlearning Complete. Time: {time.time() - start_unlearn:.4f}s")
        print(f"Compute Savings: {(self.num_shards - 1) / self.num_shards:.0%}")

# --- 2. Simulation ---
X, y = make_classification(n_samples=3000, n_features=20, random_state=42)

sisa = SISAEnsemble(num_shards=3)
sisa.fit(X, y)

# Baseline Accuracy
y_pred = sisa.predict(X)
print(f"Initial Accuracy: {accuracy_score(y, y_pred):.4f}")

# Request Deletion of a specific sample (e.g., Index 100)
# This simulates a user invoking GDPR Art. 17
sisa.unlearn(data_index_to_remove=100)

# Verify System Integrity
y_pred_new = sisa.predict(X)
print(f"Post-Deletion Accuracy: {accuracy_score(y, y_pred_new):.4f}")
Verification Strategy
- Time: The "Unlearning Time" should be roughly one third of the "Total Training Time" (assuming sequential training; with parallel training, it still cuts total compute resources and dollars even if wall-clock time is similar).
- Membership Inference: Run a loss check on index 100 against the shard model where it used to reside — the loss should increase significantly, as the point moves from the "training set" distribution to the "test set" distribution.
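That loss check can be sketched as follows, assuming we keep the pre-unlearning shard model around for comparison (the `sample_loss` helper and the shard layout are illustrative; how much the loss rises depends on how strongly the model memorized the point):

```python
import numpy as np
from sklearn.svm import SVC
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=600, n_features=20, random_state=42)
idx = 100                                    # the deleted sample

def sample_loss(model, x, t):
    """Per-sample log-loss: low for memorized points, higher for unseen ones."""
    p = model.predict_proba(x.reshape(1, -1))[0, int(t)]
    return -np.log(p + 1e-12)

shard = np.arange(0, 200)                    # pretend index 100 lived here
member = SVC(probability=True, random_state=42).fit(X[shard], y[shard])

# Exact unlearning of this shard: retrain without the point
shard_after = shard[shard != idx]
scrubbed = SVC(probability=True, random_state=42).fit(X[shard_after], y[shard_after])

before = sample_loss(member, X[idx], y[idx])
after = sample_loss(scrubbed, X[idx], y[idx])
print(f"loss while member: {before:.4f}  loss after unlearning: {after:.4f}")
```

An auditor running a membership inference attack would see index 100 behave like any other held-out point once the scrubbed model is serving.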
6. Ethical, Security & Safety Considerations
The "Streisand Effect" of Unlearning
If you use Approximate Unlearning (e.g., gradient ascent), you might accidentally create a "negative memory." The model might not just forget the user; it might aggressively learn to not be the user, creating an anomaly that reveals the user existed. SISA avoids this by retraining from scratch, providing Exact Unlearning.
Privacy vs. Utility
Splitting data into 100 shards makes unlearning instant (1% cost), but the model will likely fail to learn complex patterns because no single sub-model sees enough data.
- Safety Check: Monitor the ensemble accuracy. If it drops below a threshold, you must merge shards (reduce k) and accept higher unlearning costs.
7. Business & Strategic Implications
- Cost Cap: SISA creates a predictable upper bound on compliance costs. You know exactly how much CPU/GPU a deletion request consumes.
- Tiered Service: You can charge enterprise customers for "Instant Unlearning" (smaller shards, higher cost) vs. "Standard Unlearning" (batch processing deletions once a week).
- Audit Trail: SISA provides a clean audit log. "User X was in Shard 4. Shard 4 was retrained at timestamp T. Hash of new weights is H."
8. Common Pitfalls & Misconceptions
- Pitfall: Forgetting Shared Layers.
  - Correction: If you use a pre-trained feature extractor (e.g., BERT) and only SISA-train the classification head, the user's data might still be encoded in the fine-tuned embeddings of the base model. For full compliance, the entire trainable stack must be sharded, or the base model must remain frozen/public.
- Pitfall: Data Duplication.
  - Correction: SISA requires disjoint shards. If a user's data is duplicated across shards, you defeat the purpose (you'd have to retrain multiple shards).
- Pitfall: Ignoring Logs.
  - Correction: Unlearning weights is useless if the raw data persists in your access.log or S3 backup buckets.
9. Prerequisites & Next Steps
Prerequisites:
- Understanding of Ensemble Learning (Bagging).
- Basic Data Structures (Hash Maps for indexing).
Next Steps:
- Implement Slicing: Add checkpointing within shards to further reduce retraining time.
- Parallelize: Use
jobliborRayto train shards in parallel. - Policy: Define a "Retraining Cadence." Maybe you don't retrain immediately; you mark the shard as "dirty" and retrain when it accumulates 50 deletions.
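The "dirty shard" cadence can be sketched as a small bookkeeping class (the `DirtyShardPolicy` name and threshold are hypothetical; note that GDPR expects erasure "without undue delay," so any batching window needs legal sign-off):

```python
from collections import defaultdict

class DirtyShardPolicy:
    """Batch deletions: mark a shard dirty and retrain only once it
    accumulates `threshold` pending requests (hypothetical policy knob)."""

    def __init__(self, threshold=50):
        self.threshold = threshold
        self.pending = defaultdict(list)   # shard_id -> pending user ids
        self.retrain_log = []

    def request_deletion(self, shard_id, user_id):
        self.pending[shard_id].append(user_id)
        if len(self.pending[shard_id]) >= self.threshold:
            self._retrain(shard_id)

    def _retrain(self, shard_id):
        # One retraining job erases the whole pending batch at once
        batch = self.pending.pop(shard_id)
        self.retrain_log.append((shard_id, len(batch)))

policy = DirtyShardPolicy(threshold=3)
for uid in range(7):                       # 7 requests, all landing in shard 0
    policy.request_deletion(shard_id=0, user_id=uid)

print(policy.retrain_log)                  # two batched retrains; one request still pending
```

Seven requests trigger only two retraining jobs, amortizing compute across deletions at the cost of a bounded delay for each user.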
With unlearning in place, every individual compliance mechanism from Days 51–59 is operational. The final step is to wire them together. Day 60: The White Box Capstone: The Audit Defense synthesizes explainability, fairness, provenance, and unlearning into a single production API that can satisfy a regulator's challenge in real time.
10. Further Reading & Resources
- Paper: "Machine Unlearning" (Bourtoule et al., 2021) – The original SISA paper.
- Regulation: GDPR Article 17 (Right to Erasure).
- Library: Microsoft/unlearning-survey – Curated research on machine unlearning.
- Concept: Visualizing how shards and slices interact to minimize compute.