model-extraction-relu-logits
Model Extraction for ReLU Networks
This skill provides guidance for extracting internal weight matrices from black-box ReLU neural networks using only input-output access.
Problem Understanding
Model extraction tasks typically involve:
- A black-box neural network that accepts inputs and returns outputs (logits)
- The goal of recovering internal parameters (weight matrices, biases)
- No direct access to the network's implementation or internal state
Critical Principle: True Black-Box Treatment
Treat the target network as a genuine black-box. Never rely on implementation details that may change during evaluation:
- Do not hardcode hidden layer dimensions from example code
- Do not assume specific random seeds or initialization schemes
- Do not directly compare extracted weights to "true" weights read from source files
- The test environment may use completely different parameters than any provided examples
Approach Selection
Understanding ReLU Network Structure
A two-layer ReLU network computes: output = A2 @ ReLU(A1 @ x + b1) + b2
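As a concrete reference, here is a minimal NumPy sketch of that forward pass (the function and variable names are illustrative, not assumptions about the target):

import numpy as np

def two_layer_relu(x, A1, b1, A2, b2):
    # Hidden layer: affine map followed by elementwise max(., 0)
    h = np.maximum(A1 @ x + b1, 0.0)
    # Output layer: affine map on the hidden activations
    return A2 @ h + b2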
Key properties to exploit:
- Piecewise linearity: ReLU networks are piecewise linear functions
- Activation boundaries: Each hidden neuron i defines a hyperplane {x : A1[i] @ x + b1[i] = 0} across which it switches between inactive (output zero) and active
- Gradient structure: Within a linear region the Jacobian equals A2 @ diag(s) @ A1, where s is the region's 0/1 activation pattern, so gradients reveal which neurons are active
Recommended Extraction Strategies
Strategy 1: Critical Point Analysis
ReLU networks have critical points where neurons transition between active/inactive states:
- Probe the network systematically to identify transition boundaries
- At each boundary, the hyperplane's normal vector recovers a row of A1 up to sign and scale (see the bisection sketch after this list)
- Collect enough boundaries to reconstruct A1
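A standard way to locate one boundary is to restrict the network to a line segment, where the output becomes piecewise linear in a single scalar, and bisect on nonlinearity. A minimal sketch, assuming the segment contains exactly one transition and that query returns a single logit:

import numpy as np

def find_transition(query, p, d, t_lo=0.0, t_hi=1.0, iters=60, tol=1e-10):
    # f(t) = query(p + t*d) is piecewise linear in t; a kink inside [t_lo, t_hi]
    # makes the midpoint value deviate from the chord between the endpoints.
    f = lambda t: float(query(p + t * d))
    for _ in range(iters):
        mid = 0.5 * (t_lo + t_hi)
        left_mid = 0.5 * (t_lo + mid)
        # If the left half still bends, the kink is there; otherwise it is on the right.
        if abs(f(left_mid) - 0.5 * (f(t_lo) + f(mid))) > tol:
            t_hi = mid
        else:
            t_lo = mid
    return p + 0.5 * (t_lo + t_hi) * d  # approximate critical point

At the recovered point, gradients on the two sides of the boundary differ by a rank-one term proportional to the crossing neuron's row of A1, which is how the hyperplane normal is obtained in practice.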
Strategy 2: Gradient-Based Extraction
For networks where gradients are accessible or can be approximated:
- Query gradients at multiple random points
- Gradients in a linear region reveal which neurons are active
- Use gradient information to identify weight matrix rows
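When exact gradients are not exposed, central finite differences approximate the Jacobian from queries alone. A sketch, assuming query returns a logit vector as a NumPy array:

import numpy as np

def numerical_jacobian(query, x, eps=1e-5):
    # Central differences: J[:, j] ~ (f(x + eps*e_j) - f(x - eps*e_j)) / (2*eps).
    # Accurate only if both probe points fall in the same linear region.
    y0 = np.asarray(query(x))
    J = np.zeros((y0.shape[0], x.shape[0]))
    for j in range(x.shape[0]):
        e = np.zeros_like(x)
        e[j] = eps
        J[:, j] = (np.asarray(query(x + e)) - np.asarray(query(x - e))) / (2 * eps)
    return J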
Strategy 3: Activation Pattern Enumeration
Systematically identify which neurons are active in different input regions:
- Start from a known point and identify its activation pattern
- Search for inputs that cause different neurons to activate
- Use the transition points to extract hyperplane parameters
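Because each activation pattern generically yields a distinct linear map, the local Jacobian can serve as a black-box fingerprint of a region. A sketch building on numerical_jacobian above; input_dim and the rounding precision are assumptions to tune:

def region_signature(query, x, decimals=4):
    # Inputs in the same linear region share (numerically) the same Jacobian,
    # so a rounded Jacobian identifies the region's activation pattern.
    J = numerical_jacobian(query, x)
    return tuple(np.round(J, decimals).ravel())

patterns = {region_signature(query, np.random.randn(input_dim)) for _ in range(500)}
print(f"observed {len(patterns)} distinct linear regions")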
Strategy 4: Optimization-Based Fitting (Fallback)
When mathematically principled methods are insufficient:
- Generate diverse input-output pairs from the black-box
- Train a surrogate network to match outputs
- Critical: Make network capacity adaptive (try multiple hidden dimensions)
- Validate by output matching, not parameter comparison
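A minimal surrogate-fitting sketch in PyTorch, assuming (X_train, Y_train) and (X_val, Y_val) are float tensors of queried input-output pairs; the candidate widths and hyperparameters are illustrative:

import torch
import torch.nn as nn

def fit_surrogate(X, Y, hidden, epochs=5000, lr=1e-3):
    # Two-layer ReLU surrogate matching the assumed target architecture.
    model = nn.Sequential(
        nn.Linear(X.shape[1], hidden),
        nn.ReLU(),
        nn.Linear(hidden, Y.shape[1]),
    )
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(X), Y)
        loss.backward()
        opt.step()
    return model

# Adaptive capacity: sweep hidden widths, select by error on held-out queries.
best = min(
    (fit_surrogate(X_train, Y_train, h) for h in (8, 16, 32, 64)),
    key=lambda m: nn.functional.mse_loss(m(X_val), Y_val).item(),
)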
Hidden Dimension Discovery
Since the hidden dimension is unknown, employ detection strategies:
- Rank analysis: Gradients from every linear region lie in the row space of A1, so the numerical rank of stacked gradients estimates the hidden size when the input dimension is at least as large (see the sketch after this list)
- Binary search: Try different hidden sizes and measure reconstruction error
- Overcomplete fitting: Use larger hidden dimension than necessary, then identify redundant neurons
- Gradient counting: Sample gradients at many points; the gradient is constant within each linear region, so the number of distinct gradient patterns observed grows with the number of hidden neurons
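A sketch of the rank-based estimate, reusing numerical_jacobian from the gradient strategy above (the probe count and tolerance are assumptions to tune):

def estimate_hidden_size(query, input_dim, n_probes=200, tol=1e-6):
    # Every Jacobian row lies in the row space of A1; across enough regions
    # the stacked rows span it, so the numerical rank approximates the hidden
    # size (assuming hidden size <= input dimension and generic weights).
    rows = np.vstack([
        numerical_jacobian(query, np.random.randn(input_dim))
        for _ in range(n_probes)
    ])
    return int(np.linalg.matrix_rank(rows, tol=tol))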
Verification Strategy
Correct Verification (Functional Equivalence)
import numpy as np

# Generate test inputs NOT used during extraction
test_inputs = generate_diverse_inputs(n=1000)

# Compare outputs on the held-out inputs
original_outputs = np.array([black_box_query(x) for x in test_inputs])
extracted_outputs = np.array([extracted_model(x) for x in test_inputs])

# Check functional equivalence via worst-case absolute error
max_error = np.max(np.abs(original_outputs - extracted_outputs))
assert max_error < tolerance
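Scale the tolerance to the magnitude of the logits the black-box actually produces; a fixed absolute threshold is only meaningful once typical output magnitudes are known.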
Incorrect Verification (Avoid These)
- Comparing extracted weights directly to weights read from source files
- Using the same inputs for extraction and verification
- Relying on cosine similarity to "true" parameters
- Checking only a small number of test points
Common Pitfalls
1. Peeking at Implementation Details
Problem: Reading source code to get the "true" weights or hidden dimension, then validating against them.
Why it fails: Test environments often use different parameters (different seeds, dimensions, scales).
Solution: Treat extraction as if source code doesn't exist. Validate only through output comparison.
2. Hardcoding Network Architecture
Problem: Assuming hidden dimension is fixed (e.g., n_neurons=20).
Why it fails: The actual network may have a different architecture.
Solution: Either detect hidden dimension empirically or design extraction to work with unknown dimensions.
3. Non-Unique Solutions
Problem: Many weight configurations produce identical input-output behavior.
Why it fails: Optimization may find a valid equivalent representation, not the original weights.
Solution: If the task requires recovering specific original weights (not just functional equivalents), use mathematically principled extraction that exploits ReLU structure.
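The scaling symmetry is easy to demonstrate: ReLU is positively homogeneous (ReLU(c*z) = c*ReLU(z) for c > 0), so rescaling one neuron's incoming weights by c and its outgoing weights by 1/c leaves the function unchanged, as does permuting neurons. A quick check reusing two_layer_relu from above, with A1, b1, A2, b2 assumed defined:

c = 3.0  # any positive scale
A1_alt, b1_alt, A2_alt = A1.copy(), b1.copy(), A2.copy()
A1_alt[0] *= c        # rescale neuron 0's incoming weights...
b1_alt[0] *= c        # ...and its bias...
A2_alt[:, 0] /= c     # ...then compensate on the outgoing side
x = np.random.randn(A1.shape[1])
assert np.allclose(two_layer_relu(x, A1, b1, A2, b2),
                   two_layer_relu(x, A1_alt, b1_alt, A2_alt, b2))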
4. Insufficient Test Coverage
Problem: Verifying on a few hand-picked inputs.
Why it fails: The extracted model may fail on untested input regions.
Solution: Use comprehensive random testing across the input domain, including edge cases.
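One way to assemble such a test set is to mix Gaussian samples at several scales with axis-aligned edge cases; this sketch also fleshes out the generate_diverse_inputs placeholder used in the verification snippet (the scales are illustrative):

def generate_diverse_inputs(n, input_dim, scales=(0.1, 1.0, 10.0)):
    batches = [np.random.randn(n // len(scales), input_dim) * s for s in scales]
    batches.append(np.eye(input_dim))          # axis-aligned unit probes
    batches.append(np.zeros((1, input_dim)))   # the origin
    return np.vstack(batches)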
5. Numerical Precision Issues
Problem: Accumulated floating-point errors cause extraction to fail.
Solution: Use numerically stable algorithms, appropriate tolerances, and verify with realistic precision expectations.
Implementation Checklist
Before declaring success, verify:
- No implementation details (seeds, dimensions) were read from source files
- Hidden dimension was detected or handled adaptively
- Verification uses only input-output comparisons
- Verification inputs are independent from extraction inputs
- Sufficient test coverage (hundreds to thousands of points)
- Error tolerance is appropriate for the task requirements
- The extracted model works as a functional replacement
When Standard Approaches Fail
If initial extraction attempts fail:
- Increase probe density: More input-output pairs may be needed
- Try multiple hidden dimensions: The assumed size may be wrong
- Check for numerical issues: Scaling, precision, or conditioning problems
- Verify the network structure: Ensure assumptions about architecture (two-layer, ReLU) are correct
- Consider alternative representations: Some equivalent parameterizations may be easier to extract