# Federated Learning + DQN

Privacy-preserving distributed reinforcement learning for healthcare scheduling.
## When to Use
- Multi-institution ML without sharing raw data
- Healthcare applications with privacy requirements
- Distributed optimization across organizations
## Architecture Overview

```
┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│ Hospital A  │     │ Hospital B  │     │ Hospital C  │
│ Local DQN   │     │ Local DQN   │     │ Local DQN   │
└──────┬──────┘     └──────┬──────┘     └──────┬──────┘
       │                   │                   │
       └───────────────────┼───────────────────┘
                           │
                    ┌──────▼──────┐
                    │ Aggregator  │
                    │  (Server)   │
                    └─────────────┘
```
## Components

### Federated Learning

**FedAvg Algorithm:**
```python
# Server-side weighted averaging of client state dicts
def federated_averaging(models, weights):
    total = sum(weights)
    averaged = {}
    for key in models[0].state_dict():
        averaged[key] = sum(
            w * model.state_dict()[key]
            for model, w in zip(models, weights)
        ) / total
    return averaged

# One communication round (avoid shadowing the builtin `round`)
for rnd in range(num_rounds):
    clients = select_clients()
    models, weights = [], []
    for client in clients:
        model, weight = client.train(local_epochs)
        models.append(model)
        weights.append(weight)
    global_model.load_state_dict(federated_averaging(models, weights))
```
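The `client.train(local_epochs)` call above is left abstract. A minimal client sketch, assuming a hypothetical `Client` class that trains a deep copy of the global model on a local data loader and reports its sample count as the FedAvg weight:

```python
import copy
import torch
import torch.nn as nn

class Client:
    """Hypothetical federated client: trains a local copy of the global model.

    Raw data never leaves the client; only the trained weights are returned.
    """
    def __init__(self, model, data_loader, lr=0.01):
        self.model = copy.deepcopy(model)   # local copy of the global model
        self.data_loader = data_loader      # iterable of (x, y) batches
        self.lr = lr

    def train(self, local_epochs):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
        loss_fn = nn.MSELoss()
        for _ in range(local_epochs):
            for x, y in self.data_loader:
                optimizer.zero_grad()
                loss = loss_fn(self.model(x), y)
                loss.backward()
                optimizer.step()
        # FedAvg weights rounds by local sample count
        num_samples = sum(x.shape[0] for x, _ in self.data_loader)
        return self.model, num_samples
```

Weighting by sample count makes the aggregate equivalent to training on the pooled data when client objectives align.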
### Deep Q-Network (DQN)

**Network Architecture:**
```python
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )

    def forward(self, x):
        return self.net(x)
```
**Training Loop:**
```python
def train_dqn(agent, replay_buffer, target_net, optimizer):
    for episode in range(num_episodes):
        state = env.reset()
        done = False
        while not done:
            # Epsilon-greedy action
            action = agent.select_action(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            # Store transition
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            if len(replay_buffer) < batch_size:
                continue  # wait until the buffer holds a full batch
            # Sample batch
            batch = replay_buffer.sample(batch_size)
            # Q(s, a) for the actions actually taken
            q_values = agent(batch.state).gather(1, batch.action.unsqueeze(1)).squeeze(1)
            # TD target; no gradients flow through the target network
            with torch.no_grad():
                next_q = target_net(batch.next_state).max(1)[0]
                target = batch.reward + gamma * next_q * (1 - batch.done)
            loss = nn.MSELoss()(q_values, target)
            # Gradient update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Periodically sync the target network
        if episode % target_update == 0:
            target_net.load_state_dict(agent.state_dict())
```
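The `select_action` used above is assumed to be epsilon-greedy: explore with a random action with probability `epsilon`, otherwise exploit the greedy action. A minimal standalone sketch (function names are illustrative):

```python
import random
import torch

def select_action(q_net, state, epsilon, action_dim):
    """Epsilon-greedy: random action with probability epsilon, else argmax Q."""
    if random.random() < epsilon:
        return random.randrange(action_dim)  # explore
    with torch.no_grad():
        q_values = q_net(torch.as_tensor(state, dtype=torch.float32))
        return int(q_values.argmax().item())  # exploit
```

In practice `epsilon` is typically annealed from ~1.0 toward a small floor (e.g. 0.05) over training.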
### Multi-Level Feedback Queue (MLFQ)

**Integration with DQN:**
```python
class MLFQScheduler:
    def __init__(self, num_queues=3):
        self.queues = [[] for _ in range(num_queues)]  # index 0 = highest priority
        self.priority_boost = 10  # wait-time threshold before promotion

    def add_patient(self, patient, priority):
        queue_idx = min(priority, len(self.queues) - 1)
        self.queues[queue_idx].append(patient)

    def get_next_patient(self):
        # DQN selects which queue to serve
        queue_state = self.get_queue_state()
        action = dqn_agent.select_action(queue_state)
        # Boost priority of long-waiting patients first
        self.boost_priorities()
        return self.queues[action].pop(0) if self.queues[action] else None

    def boost_priorities(self):
        for i in range(len(self.queues) - 1, 0, -1):
            # Iterate over a copy: removing from a list while iterating skips items
            for patient in list(self.queues[i]):
                if patient.wait_time > self.priority_boost:
                    self.queues[i].remove(patient)
                    self.queues[i - 1].append(patient)
```
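A small self-contained demo of the add/boost logic (hypothetical minimal `Patient` record; the DQN-driven queue selection is elided). Note that because the boost loop walks queues from low to high priority, a patient who has waited past the threshold can be promoted several levels in a single call:

```python
from dataclasses import dataclass

@dataclass
class Patient:                 # hypothetical minimal patient record
    name: str
    wait_time: int = 0

class MLFQ:                    # add/boost logic only; DQN queue choice elided
    def __init__(self, num_queues=3, priority_boost=10):
        self.queues = [[] for _ in range(num_queues)]
        self.priority_boost = priority_boost

    def add_patient(self, patient, priority):
        self.queues[min(priority, len(self.queues) - 1)].append(patient)

    def boost_priorities(self):
        for i in range(len(self.queues) - 1, 0, -1):
            for patient in list(self.queues[i]):   # copy: safe removal
                if patient.wait_time > self.priority_boost:
                    self.queues[i].remove(patient)
                    self.queues[i - 1].append(patient)

sched = MLFQ()
sched.add_patient(Patient("a", wait_time=15), priority=2)
sched.add_patient(Patient("b", wait_time=3), priority=2)
sched.boost_priorities()
# "a" waited past the threshold and cascades up to queue 0; "b" stays in queue 2
```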
## Privacy Guarantees

### Differential Privacy
```python
import numpy as np
import torch

def add_dp_noise(gradients, epsilon, delta, sensitivity):
    """Add Gaussian noise for (ε, δ)-differential privacy."""
    sigma = sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
    noise = torch.randn_like(gradients) * sigma
    return gradients + noise
```
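The `sensitivity` argument is usually enforced rather than assumed: each client's update is clipped to a fixed L2 norm before noising, so the clip norm bounds any one client's contribution. A sketch (the `clip_update` helper is illustrative, not part of the API above):

```python
import torch

def clip_update(update, clip_norm):
    """Rescale the update so its L2 norm is at most clip_norm.

    With this bound in place, clip_norm can be passed as the
    `sensitivity` in the Gaussian-noise calibration above.
    """
    norm = torch.linalg.vector_norm(update)
    scale = torch.clamp(clip_norm / (norm + 1e-12), max=1.0)
    return update * scale
```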
### Secure Aggregation
- Clients encrypt model updates
- Server aggregates without seeing individual updates
- Only decrypted aggregate is visible
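The core idea can be sketched with pairwise additive masks that cancel in the sum (a toy scalar version; real protocols add key agreement, dropout recovery, and finite-field arithmetic):

```python
import random

def masked_updates(updates, seed=0):
    """Each client pair (i, j) shares a random mask r_ij.

    Client i uploads its value + r_ij, client j uploads its value - r_ij,
    so every individual upload looks random while the server's sum
    equals the true sum of the updates.
    """
    n = len(updates)
    rng = random.Random(seed)   # stand-in for pairwise agreed secrets
    masked = list(updates)
    for i in range(n):
        for j in range(i + 1, n):
            r = rng.uniform(-1000.0, 1000.0)
            masked[i] += r
            masked[j] -= r
    return masked

masked = masked_updates([1.0, 2.0, 3.0])
# server only sees `masked`; their sum still equals 6.0
```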
## Healthcare Scheduling Use Case

### State Representation
```python
state = {
    'queue_lengths': [len(q) for q in queues],      # Shape: (num_queues,)
    'patient_acuity': average_acuity_per_queue,     # Shape: (num_queues,)
    'resource_availability': [beds, staff, equipment],
    'time_features': [hour_of_day, day_of_week],
    'predicted_arrivals': next_hour_forecast,
}
```
### Action Space
```python
actions = {
    0: 'Schedule from high-priority queue',
    1: 'Schedule from medium-priority queue',
    2: 'Schedule from low-priority queue',
    3: 'Allocate additional resource',
    4: 'Request transfer from other hospital',
}
```
### Reward Function
```python
def calculate_reward(state, action, next_state):
    reward = 0
    # Minimize wait time (weighted by acuity)
    reward -= sum(
        patient.wait_time * patient.acuity
        for patient in all_patients
    )
    # Penalize queue imbalance
    reward -= variance(queue_lengths) * 10
    # Reward completing high-acuity cases
    reward += completed_high_acuity * 50
    # Penalize resource overutilization
    if resource_utilization > threshold:
        reward -= overutilization_penalty
    return reward
```
## Implementation Considerations

### Communication Efficiency

- **Compression**: Quantize model updates
- **Federated Dropout**: Train smaller subnetworks
- **Asynchronous Updates**: No synchronization barrier
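Compression can be as simple as uniform 8-bit quantization of each update tensor, shipping `int8` values plus one float scale instead of `float32` (a roughly 4x smaller payload). A sketch with a per-tensor scale:

```python
import torch

def quantize_8bit(t):
    """Uniform 8-bit quantization: int8 values plus one float scale."""
    scale = t.abs().max().clamp(min=1e-12) / 127.0
    q = torch.round(t / scale).to(torch.int8)
    return q, scale.item()

def dequantize_8bit(q, scale):
    """Server-side reconstruction before averaging."""
    return q.to(torch.float32) * scale

update = torch.tensor([0.5, -1.0, 0.25])
q, s = quantize_8bit(update)
restored = dequantize_8bit(q, s)   # close to the original values
```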
### Handling Non-IID Data

- **Personalization**: Fine-tune the global model locally
- **Clustered FL**: Group similar hospitals
- **Multi-task Learning**: Shared representation + task-specific heads
### System Heterogeneity

- **Straggler Handling**: Async aggregation or timeout
- **Variable Resources**: Adaptive local epochs
- **Device Selection**: Probabilistic client sampling
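Probabilistic client sampling is commonly weighted by local dataset size, so larger hospitals are selected more often without ever being guaranteed a slot. A sketch of size-weighted sampling without replacement (hospital names and sizes are made up):

```python
import random

def sample_clients(client_sizes, k, seed=None):
    """Sample up to k distinct clients, weighted by local dataset size."""
    rng = random.Random(seed)
    remaining = list(client_sizes)
    chosen = []
    for _ in range(min(k, len(remaining))):
        weights = [client_sizes[c] for c in remaining]
        pick = rng.choices(remaining, weights=weights, k=1)[0]
        chosen.append(pick)
        remaining.remove(pick)
    return chosen

sizes = {'hospital_a': 5000, 'hospital_b': 1200, 'hospital_c': 300}
round_clients = sample_clients(sizes, k=2, seed=42)
```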
## Evaluation Metrics
| Metric | Description |
|---|---|
| Privacy Budget (ε) | Differential privacy guarantee |
| Model Accuracy | Comparison to centralized training |
| Communication Rounds | Convergence speed |
| Patient Wait Time | Scheduling effectiveness |
| Resource Utilization | System efficiency |