
Federated Learning + DQN

Privacy-preserving distributed reinforcement learning for healthcare scheduling.

When to Use

  • Multi-institution ML without sharing raw data
  • Healthcare applications with privacy requirements
  • Distributed optimization across organizations

Architecture Overview

┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│  Hospital A │     │  Hospital B │     │  Hospital C │
│  Local DQN  │     │  Local DQN  │     │  Local DQN  │
└──────┬──────┘     └──────┬──────┘     └──────┬──────┘
       │                   │                   │
       └───────────────────┼───────────────────┘
                    ┌──────▼──────┐
                    │  Aggregator │
                    │  (Server)   │
                    └─────────────┘

Components

Federated Learning

FedAvg Algorithm:

# Server: weighted average of client parameters (FedAvg)
def federated_averaging(models, weights):
    total = sum(weights)
    averaged = {}
    for key in models[0].state_dict():
        averaged[key] = sum(
            w * model.state_dict()[key]
            for model, w in zip(models, weights)
        ) / total
    return averaged

# One communication round
for rnd in range(num_rounds):  # avoid shadowing the built-in round()
    clients = select_clients()
    models, weights = [], []
    for client in clients:
        model, weight = client.train(local_epochs)  # weight = local sample count
        models.append(model)
        weights.append(weight)
    global_model.load_state_dict(federated_averaging(models, weights))

Deep Q-Network (DQN)

Network Architecture:

import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )
    
    def forward(self, x):
        return self.net(x)

Training Loop:

def train_dqn(agent, env, target_net, replay_buffer, optimizer,
              num_episodes, batch_size, gamma, epsilon, target_update):
    for episode in range(num_episodes):
        state = env.reset()
        done = False

        while not done:
            # Epsilon-greedy action selection
            action = agent.select_action(state, epsilon)
            next_state, reward, done, _ = env.step(action)

            # Store transition
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state

            # Wait until the buffer can fill a batch
            if len(replay_buffer) < batch_size:
                continue

            # Sample batch (fields assumed stacked as tensors;
            # action is a 1-D LongTensor, done a 1-D float tensor)
            batch = replay_buffer.sample(batch_size)

            # Q-values of the actions actually taken
            q_values = agent(batch.state).gather(1, batch.action.unsqueeze(1)).squeeze(1)

            # TD target, with no gradient flowing through the target network
            with torch.no_grad():
                next_q = target_net(batch.next_state).max(1)[0]
                target = batch.reward + gamma * next_q * (1 - batch.done)

            loss = nn.MSELoss()(q_values, target)

            # Gradient update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Periodically sync the target network with the online network
        if episode % target_update == 0:
            target_net.load_state_dict(agent.state_dict())

Multi-Level Feedback Queue (MLFQ)

Integration with DQN:

class MLFQScheduler:
    def __init__(self, num_queues=3):
        self.queues = [[] for _ in range(num_queues)]
        self.priority_boost = 10
        
    def add_patient(self, patient, priority):
        queue_idx = min(priority, len(self.queues) - 1)
        self.queues[queue_idx].append(patient)
    
    def get_next_patient(self):
        # The locally trained DQN picks which queue to serve next
        # (dqn_agent is assumed to be in scope)
        queue_state = self.get_queue_state()
        action = dqn_agent.select_action(queue_state)

        # Boost priority of long-waiting patients before dispatching
        self.boost_priorities()

        return self.queues[action].pop(0) if self.queues[action] else None

    def boost_priorities(self):
        # Iterate over a copy: removing from a list while iterating it skips items
        for i in range(len(self.queues) - 1, 0, -1):
            for patient in list(self.queues[i]):
                if patient.wait_time > self.priority_boost:
                    self.queues[i].remove(patient)
                    self.queues[i - 1].append(patient)

Privacy Guarantees

Differential Privacy

import numpy as np
import torch

def add_dp_noise(gradients, epsilon, delta, sensitivity):
    """Add Gaussian noise for (ε,δ)-differential privacy."""
    sigma = sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
    noise = torch.randn_like(gradients) * sigma
    return gradients + noise
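
In practice the sensitivity bound comes from clipping each client's update before noising. A minimal NumPy sketch over a flattened update vector, using the same sigma formula as above (function and parameter names are illustrative):

```python
import numpy as np

def clip_and_noise(update, clip_norm, epsilon, delta, rng=None):
    """Clip a flat update to L2 norm clip_norm, then add Gaussian noise
    calibrated to that sensitivity."""
    rng = rng or np.random.default_rng()
    norm = np.linalg.norm(update)
    clipped = update * min(1.0, clip_norm / max(norm, 1e-12))
    sigma = clip_norm * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
    return clipped + rng.normal(0.0, sigma, size=update.shape)
```

Clipping guarantees the L2 sensitivity of any single client's contribution is at most `clip_norm`, which is what the Gaussian-mechanism formula requires.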

Secure Aggregation

  • Clients encrypt model updates
  • Server aggregates without seeing individual updates
  • Only decrypted aggregate is visible
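
The masking idea behind the bullets above can be illustrated with pairwise additive masks that cancel in the server's sum (a toy sketch, not a real cryptographic protocol; seed derivation here is illustrative):

```python
import itertools
import numpy as np

def masked_update(client_id, update, peer_seeds):
    """Add a shared pairwise mask: +mask toward higher-id peers, -mask toward lower,
    so all masks cancel when the server sums every client's masked update."""
    masked = update.copy()
    for peer_id, seed in peer_seeds[client_id].items():
        mask = np.random.default_rng(seed).normal(size=update.shape)
        masked += mask if client_id < peer_id else -mask
    return masked

# Toy run with 3 clients: the server's sum equals the unmasked sum
updates = {0: np.array([1.0, 2.0]), 1: np.array([3.0, 4.0]), 2: np.array([5.0, 6.0])}
seeds = {i: {} for i in updates}
for i, j in itertools.combinations(updates, 2):
    seeds[i][j] = seeds[j][i] = hash((i, j)) % (2**32)
total = sum(masked_update(i, u, seeds) for i, u in updates.items())
# total recovers the true aggregate, yet each masked update alone looks like noise
```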

Healthcare Scheduling Use Case

State Representation

state = {
    'queue_lengths': [len(q) for q in queues],  # Shape: (num_queues,)
    'patient_acuity': average_acuity_per_queue,  # Shape: (num_queues,)
    'resource_availability': [beds, staff, equipment],
    'time_features': [hour_of_day, day_of_week],
    'predicted_arrivals': next_hour_forecast,
}
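
Assuming each entry above is a flat list of numbers, the dict can be flattened into the fixed-length vector the DQN consumes (helper name and field ordering are illustrative):

```python
import numpy as np

def encode_state(state):
    """Flatten the scheduling-state dict into one float32 feature vector.
    Assumes every value is a flat list/array of numbers."""
    parts = [np.asarray(state[k], dtype=np.float32).ravel()
             for k in ('queue_lengths', 'patient_acuity',
                       'resource_availability', 'time_features',
                       'predicted_arrivals')]
    return np.concatenate(parts)
```

The length of this vector is the `state_dim` passed to the DQN constructor.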

Action Space

actions = {
    0: 'Schedule from high-priority queue',
    1: 'Schedule from medium-priority queue',
    2: 'Schedule from low-priority queue',
    3: 'Allocate additional resource',
    4: 'Request transfer from other hospital',
}

Reward Function

def calculate_reward(state, action, next_state):
    reward = 0
    
    # Minimize wait time (weighted by acuity)
    reward -= sum(
        patient.wait_time * patient.acuity 
        for patient in all_patients
    )
    
    # Penalize queue imbalance
    reward -= variance(queue_lengths) * 10
    
    # Reward completing high-acuity cases
    reward += completed_high_acuity * 50
    
    # Penalize resource overutilization
    if resource_utilization > threshold:
        reward -= overutilization_penalty
    
    return reward

Implementation Considerations

Communication Efficiency

  • Compression: Quantize model updates
  • Federated Dropout: Train smaller subnetworks
  • Asynchronous Updates: No synchronization barrier
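
The compression bullet can be sketched as symmetric 8-bit quantization of an update before upload (a minimal illustration, not a production codec):

```python
import numpy as np

def quantize8(x):
    """Symmetric uniform 8-bit quantization of a float array."""
    scale = float(np.abs(x).max()) / 127.0
    if scale == 0.0:
        scale = 1.0  # all-zero update: any scale works
    q = np.round(x / scale).astype(np.int8)
    return q, scale

def dequantize8(q, scale):
    """Recover an approximation of the original floats."""
    return q.astype(np.float32) * scale
```

Each float is replaced by one byte plus a shared scale, roughly a 4x reduction over float32, at the cost of a reconstruction error bounded by half the quantization step.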

Handling Non-IID Data

  • Personalization: Fine-tune global model locally
  • Clustered FL: Group similar hospitals
  • Multi-task Learning: Shared representation + task-specific heads
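
As one sketch of the clustered-FL idea, the server could greedily group clients whose normalized update vectors point in similar directions; the cosine threshold and greedy assignment here are illustrative, not a specific published algorithm:

```python
import numpy as np

def cluster_by_similarity(updates, threshold=0.9):
    """Greedy clustering of client update vectors (assumed nonzero):
    join the first cluster whose anchor direction has cosine similarity
    >= threshold with the client's update, else start a new cluster."""
    clusters = []  # list of (anchor_direction, [client_ids])
    for cid, u in updates.items():
        u = u / np.linalg.norm(u)
        for anchor, members in clusters:
            if np.dot(anchor, u) >= threshold:
                members.append(cid)
                break
        else:
            clusters.append((u, [cid]))
    return [members for _, members in clusters]
```

Each cluster can then run its own FedAvg round, so hospitals with similar data distributions share a model.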

System Heterogeneity

  • Straggler Handling: Async aggregation or timeout
  • Variable Resources: Adaptive local epochs
  • Device Selection: Probabilistic client sampling
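
The device-selection bullet could be one way to implement the `select_clients` step in the FedAvg round above: sample without replacement with probability proportional to local dataset size (signature and weighting are an assumption):

```python
import numpy as np

def select_clients(client_sizes, num_select, rng=None):
    """Sample client ids without replacement, probability proportional
    to each client's local dataset size."""
    rng = rng or np.random.default_rng()
    ids = np.array(list(client_sizes))
    p = np.array([client_sizes[i] for i in ids], dtype=float)
    p /= p.sum()
    return list(rng.choice(ids, size=num_select, replace=False, p=p))
```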

Evaluation Metrics

Metric                   Description
Privacy Budget (ε)       Differential-privacy guarantee
Model Accuracy           Comparison to centralized training
Communication Rounds     Convergence speed
Patient Wait Time        Scheduling effectiveness
Resource Utilization     System efficiency
