# Model Extraction Prevention: Detecting and Blocking Model Stealing Through API Queries
## Problem
Model extraction (model stealing) is an attack where an adversary queries a production ML API systematically to reconstruct a functionally equivalent copy of the model. The attacker sends carefully chosen inputs, collects outputs (predictions, probabilities, embeddings), and trains a surrogate model that replicates the target’s behaviour. A stolen model gives the attacker free inference (bypassing API costs), the ability to find adversarial examples offline, and access to proprietary capabilities without licensing.
The attack does not require exploiting a vulnerability. It uses the model’s own API exactly as designed. Every prediction returned is a training sample for the attacker’s surrogate. With modern distillation techniques, a few thousand queries can extract a high-fidelity copy of many classification and regression models. For LLMs, systematic querying can extract fine-tuning data, alignment preferences, and decision boundaries.
Rate limiting alone is insufficient. Sophisticated attackers spread queries across time, rotate API keys, and use diverse input distributions that look like normal usage.
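To make the mechanics concrete, here is a deliberately tiny illustration (not a real attack tool): the attacker logs every query/response pair returned by a stand-in "victim API" and fits a nearest-neighbour surrogate from nothing but those pairs. Every answered query directly improves the copy. The `target_model` and `Surrogate` names are invented for this sketch.

```python
# surrogate_sketch.py - why every API response is free training data (illustrative)
import numpy as np


def target_model(x: np.ndarray) -> int:
    """Stand-in for the victim API: a simple linear decision rule."""
    return int(x.sum() > 0)


class Surrogate:
    """Attacker-side copy trained only on observed query/response pairs."""

    def __init__(self):
        self.inputs: list = []
        self.labels: list = []

    def observe(self, x: np.ndarray, label: int):
        self.inputs.append(x)
        self.labels.append(label)

    def predict(self, x: np.ndarray) -> int:
        # Predict the label of the nearest previously queried input
        dists = np.linalg.norm(np.array(self.inputs) - x, axis=1)
        return self.labels[int(np.argmin(dists))]


rng = np.random.default_rng(0)
surrogate = Surrogate()
# "Query the API" a few hundred times and log every response
for _ in range(500):
    x = rng.uniform(-1, 1, size=2)
    surrogate.observe(x, target_model(x))

# Measure how faithfully the copy replicates the target on fresh inputs
test = rng.uniform(-1, 1, size=(1000, 2))
agreement = float(np.mean([surrogate.predict(x) == target_model(x) for x in test]))
print(f"surrogate agrees with target on {agreement:.0%} of fresh inputs")
```

Real attacks use stronger surrogates (gradient-boosted trees, distilled neural networks), but the economics are the same: the victim pays for the labels.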
## Threat Model
- Adversary: Competitor, researcher, or attacker with legitimate API access (valid API key, free tier account, or compromised credentials).
- Objective: Create a functionally equivalent model without training costs. Map decision boundaries for adversarial example generation. Extract proprietary fine-tuning or alignment data from LLMs.
- Blast radius: Loss of intellectual property. Competitor deploys equivalent capability. Attacker uses extracted model to craft adversarial inputs that transfer to the production model.
## Configuration

### Query Pattern Detection
```python
# query_pattern_detector.py - detect systematic probing of model decision boundaries
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np


@dataclass
class QueryProfile:
    api_key: str
    queries: List[dict] = field(default_factory=list)
    timestamps: List[float] = field(default_factory=list)
    input_embeddings: List[Optional[np.ndarray]] = field(default_factory=list)


class ExtractionDetector:
    """
    Detect model extraction attempts by analysing query patterns.

    Extraction attacks exhibit distinct statistical signatures:
    - High query volume with systematic input variation
    - Inputs clustered around decision boundaries
    - Low entropy in query distribution (not random usage)
    - Requests for full probability distributions (not just top-1)
    """

    def __init__(self, window_seconds: int = 3600, boundary_threshold: float = 0.1):
        self.profiles = defaultdict(lambda: QueryProfile(api_key=""))
        self.window = window_seconds
        self.boundary_threshold = boundary_threshold

    @staticmethod
    def _decision_margin(output_probs: np.ndarray) -> float:
        """Gap between the top-1 and top-2 probabilities; small = near a boundary."""
        max_prob = float(np.max(output_probs))
        if len(output_probs) > 1:
            return max_prob - float(np.partition(output_probs, -2)[-2])
        return max_prob

    def record_query(self, api_key: str, input_data: dict,
                     output_probs: np.ndarray,
                     embedding: Optional[np.ndarray] = None):
        profile = self.profiles[api_key]
        profile.api_key = api_key
        # Store the decision margin with the query so check_boundary_probing
        # can analyse it later (copy to avoid mutating the caller's dict)
        query = dict(input_data)
        query["_margin"] = self._decision_margin(output_probs)
        profile.queries.append(query)
        profile.timestamps.append(time.time())
        # Keep embeddings index-aligned with queries (None when absent)
        profile.input_embeddings.append(embedding)
        # Prune entries that have aged out of the sliding window
        cutoff = time.time() - self.window
        valid = [i for i, t in enumerate(profile.timestamps) if t > cutoff]
        profile.queries = [profile.queries[i] for i in valid]
        profile.timestamps = [profile.timestamps[i] for i in valid]
        profile.input_embeddings = [profile.input_embeddings[i] for i in valid]

    def check_boundary_probing(self, api_key: str) -> float:
        """
        Detect decision boundary probing.

        Extraction attacks often query near boundaries where the model is
        uncertain. Returns a score 0-1 (1 = high suspicion).
        """
        profile = self.profiles[api_key]
        if len(profile.queries) < 50:
            return 0.0
        # Boundary probing signature: many recent queries with a low decision margin
        recent = profile.queries[-100:]
        boundary_count = sum(
            1 for q in recent if q.get("_margin", 1.0) < self.boundary_threshold
        )
        return boundary_count / len(recent)

    def check_systematic_coverage(self, api_key: str) -> float:
        """
        Detect systematic input space coverage.

        Normal users cluster around specific use cases; extraction attacks
        spread uniformly across the input space.
        Returns a score 0-1 (1 = high suspicion).
        """
        profile = self.profiles[api_key]
        embeddings = [e for e in profile.input_embeddings if e is not None]
        if len(embeddings) < 100:
            return 0.0
        embeddings = np.array(embeddings[-500:])
        # Compute pairwise cosine similarity
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        normalised = embeddings / (norms + 1e-8)
        similarity_matrix = normalised @ normalised.T
        # Normal usage: high average similarity (clustered)
        # Extraction: low average similarity (spread out)
        n = len(embeddings)
        avg_similarity = (similarity_matrix.sum() - n) / (n * (n - 1))
        # Threshold: average similarity below 0.3 is suspicious
        return max(0.0, 1.0 - avg_similarity / 0.3)

    def get_risk_score(self, api_key: str) -> dict:
        boundary_score = self.check_boundary_probing(api_key)
        coverage_score = self.check_systematic_coverage(api_key)
        profile = self.profiles[api_key]
        # Volume score: queries in the current window (1000/hour = max score)
        volume = len(profile.queries)
        volume_score = min(1.0, volume / 1000)
        composite = 0.3 * volume_score + 0.4 * boundary_score + 0.3 * coverage_score
        return {
            "api_key": api_key,
            "risk_score": round(composite, 3),
            "volume_score": round(volume_score, 3),
            "boundary_probing_score": round(boundary_score, 3),
            "coverage_score": round(coverage_score, 3),
            "queries_in_window": volume,
            "action": "block" if composite > 0.7
                      else "throttle" if composite > 0.4 else "allow",
        }
```
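The coverage heuristic can be sanity-checked in isolation. This standalone sketch (synthetic embeddings, same similarity threshold as the detector) shows clustered traffic scoring near zero suspicion and uniformly spread traffic scoring near one:

```python
# coverage_check.py - average pairwise cosine similarity as a spread signal
import numpy as np


def coverage_suspicion(embeddings: np.ndarray) -> float:
    """0-1 score: low average cosine similarity (spread-out queries) = suspicious."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalised = embeddings / (norms + 1e-8)
    sim = normalised @ normalised.T
    n = len(embeddings)
    avg = (sim.sum() - n) / (n * (n - 1))  # exclude the self-similarity diagonal
    return max(0.0, 1.0 - avg / 0.3)


rng = np.random.default_rng(42)
# Normal user: queries tightly clustered around one use case
clustered = rng.normal(loc=5.0, scale=0.1, size=(200, 32))
# Extraction: queries spread across the whole embedding space
spread = rng.normal(loc=0.0, scale=1.0, size=(200, 32))

print(f"clustered suspicion: {coverage_suspicion(clustered):.2f}")  # near 0.0
print(f"spread suspicion:    {coverage_suspicion(spread):.2f}")     # near 1.0
```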
### Rate Limiting with Adaptive Thresholds

```yaml
# kong-rate-limiting.yaml - adaptive rate limiting for ML inference endpoints
apiVersion: configuration.konghq.com/v1
kind: KongPlugin
metadata:
  name: ml-inference-rate-limit
  namespace: ai-services
plugin: rate-limiting-advanced
config:
  # Base rate limits
  limit:
    - 100        # requests per window
  window_size:
    - 3600       # 1 hour
  window_type: sliding
  strategy: cluster
  sync_rate: 10
  namespace: ml-inference
  # Per-consumer limits (identified by API key)
  consumer_groups:
    - name: free-tier
      limit: [50]
      window_size: [3600]
    - name: standard
      limit: [500]
      window_size: [3600]
    - name: enterprise
      limit: [5000]
      window_size: [3600]
  # Return remaining quota in headers
  hide_client_headers: false
```
### Output Perturbation

```python
# output_perturbation.py - add controlled noise to model outputs
# This makes extraction harder without significantly affecting utility
from typing import Optional

import numpy as np


class OutputPerturbator:
    """
    Add calibrated noise to model outputs to degrade extraction quality.

    The noise is small enough that top-1 predictions are unchanged
    but large enough that probability distributions cannot be used
    to train a high-fidelity surrogate.
    """

    def __init__(self, noise_scale: float = 0.05, top_k: Optional[int] = 5):
        self.noise_scale = noise_scale
        self.top_k = top_k

    def perturb_probabilities(self, probs: np.ndarray) -> np.ndarray:
        """Add Laplace noise to the probability distribution."""
        noise = np.random.laplace(0, self.noise_scale, size=probs.shape)
        perturbed = probs + noise
        # Re-normalise to a valid probability distribution
        perturbed = np.clip(perturbed, 0, 1)
        return perturbed / perturbed.sum()

    def truncate_output(self, probs: np.ndarray, labels: list) -> dict:
        """Return only the top-k predictions instead of the full distribution."""
        if self.top_k is None:
            return {labels[i]: float(probs[i]) for i in range(len(labels))}
        top_indices = np.argsort(probs)[-self.top_k:][::-1]
        return {labels[i]: float(probs[i]) for i in top_indices}

    def process(self, probs: np.ndarray, labels: list) -> dict:
        perturbed = self.perturb_probabilities(probs)
        return self.truncate_output(perturbed, labels)
```
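Before deploying a noise scale, it is worth estimating how often the noise actually flips the top-1 prediction — the Failure Modes table below flags exactly this risk. A standalone calibration check, assuming the same Laplace-noise scheme:

```python
# noise_calibration.py - estimate how often Laplace noise flips the top-1 class
import numpy as np


def flip_rate(probs: np.ndarray, noise_scale: float, trials: int = 2000,
              rng: np.random.Generator = None) -> float:
    """Fraction of perturbations that change the argmax of a distribution."""
    rng = np.random.default_rng(0) if rng is None else rng
    original_top1 = int(np.argmax(probs))
    flips = 0
    for _ in range(trials):
        noise = rng.laplace(0, noise_scale, size=probs.shape)
        perturbed = np.clip(probs + noise, 0, 1)
        perturbed = perturbed / perturbed.sum()
        if int(np.argmax(perturbed)) != original_top1:
            flips += 1
    return flips / trials


confident = np.array([0.85, 0.10, 0.05])   # wide margin: noise rarely flips top-1
uncertain = np.array([0.40, 0.38, 0.22])   # narrow margin: flips are common
print(f"confident flip rate: {flip_rate(confident, 0.05):.3f}")
print(f"uncertain flip rate: {flip_rate(uncertain, 0.05):.3f}")
```

The asymmetry is the point: noise at this scale barely touches confident predictions, but low-margin queries — exactly the ones boundary-probing attackers care about — come back unreliable.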
### Monitoring for Extraction Attempts

```yaml
# prometheus-extraction-detection.yaml
groups:
  - name: model-extraction
    interval: 1m
    rules:
      # Track query volume per API key
      - record: inference:queries:per_key_1h
        expr: >
          sum by (api_key) (
            increase(inference_requests_total[1h])
          )

      # Alert on high query volume
      - alert: HighInferenceVolume
        expr: inference:queries:per_key_1h > 500
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "API key {{ $labels.api_key }} made {{ $value }} queries in 1h"
          description: "Investigate for potential model extraction. Normal usage is under 200 queries/hour."

      # Alert on extraction risk score
      - alert: ModelExtractionRiskHigh
        expr: model_extraction_risk_score > 0.7
        for: 5m
        labels:
          severity: critical
        annotations:
          summary: "High model extraction risk for API key {{ $labels.api_key }}"
          description: >
            Risk score: {{ $value }}. Boundary probing and systematic coverage
            indicators elevated. Consider blocking this API key.

      # Alert on requests for full probability distributions
      - alert: FullDistributionRequests
        expr: >
          rate(inference_full_distribution_requests_total[1h])
          / rate(inference_requests_total[1h]) > 0.5
        for: 30m
        labels:
          severity: warning
        annotations:
          summary: "{{ $labels.api_key }} requesting full distributions in >50% of queries"
```
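The `model_extraction_risk_score` series used in the alert above has to be exported by the inference service itself. A minimal sketch that renders per-key gauges in the Prometheus exposition format (in production you would use a client library such as prometheus_client rather than formatting by hand; `render_risk_metrics` is a name invented here):

```python
# risk_exporter.py - render per-key risk scores in Prometheus exposition format
def render_risk_metrics(scores: dict) -> str:
    """scores: api_key -> composite risk score in [0, 1]."""
    lines = [
        "# HELP model_extraction_risk_score Composite extraction risk per API key",
        "# TYPE model_extraction_risk_score gauge",
    ]
    for api_key, score in sorted(scores.items()):
        # Escape backslashes and quotes per the exposition format rules
        safe_key = api_key.replace("\\", "\\\\").replace('"', '\\"')
        lines.append(f'model_extraction_risk_score{{api_key="{safe_key}"}} {score}')
    return "\n".join(lines) + "\n"


print(render_risk_metrics({"key-abc": 0.82, "key-def": 0.12}))
```

One caveat: labelling by raw API key creates high-cardinality series, so consider hashing or bucketing keys if the key population is large.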
### Watermarking Model Outputs

```python
# output_watermark.py - embed imperceptible watermarks in model outputs
# Used to prove ownership if a stolen model is discovered
import hashlib

import numpy as np


class OutputWatermarker:
    """
    Embed a statistical watermark in model outputs.

    The watermark is imperceptible in individual predictions
    but detectable across a corpus of outputs.
    """

    def __init__(self, secret_key: str, watermark_strength: float = 0.01):
        self.secret_key = secret_key.encode()
        self.strength = watermark_strength

    def _get_watermark_signal(self, input_hash: str, output_dim: int) -> np.ndarray:
        """Generate a deterministic watermark signal from the input."""
        seed = int(hashlib.sha256(self.secret_key + input_hash.encode()).hexdigest()[:8], 16)
        rng = np.random.RandomState(seed)
        return rng.randn(output_dim) * self.strength

    def apply(self, probs: np.ndarray, input_text: str) -> np.ndarray:
        """Apply the watermark to output probabilities."""
        input_hash = hashlib.sha256(input_text.encode()).hexdigest()
        signal = self._get_watermark_signal(input_hash, len(probs))
        watermarked = np.clip(probs + signal, 0, 1)
        return watermarked / watermarked.sum()

    def detect(self, collected_outputs: list, collected_inputs: list) -> dict:
        """
        Detect watermark presence in a collection of outputs.

        Requires access to multiple input-output pairs from the suspected copy.
        """
        correlations = []
        for inp, out in zip(collected_inputs, collected_outputs):
            input_hash = hashlib.sha256(inp.encode()).hexdigest()
            expected_signal = self._get_watermark_signal(input_hash, len(out))
            correlation = np.corrcoef(out - out.mean(), expected_signal)[0, 1]
            correlations.append(correlation)
        avg_correlation = np.mean(correlations)
        return {
            "watermark_detected": avg_correlation > 0.3,
            "confidence": min(1.0, avg_correlation / 0.5),
            "samples_tested": len(correlations),
            "avg_correlation": float(avg_correlation),
        }
```
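The correlation test can be exercised end to end on synthetic data. This standalone check restates the apply/detect logic inline; with a near-uniform base distribution the embedded signal dominates the variation in each output, so the per-sample correlation is close to 1 (on sharply peaked distributions it would be weaker, and detection needs more samples):

```python
# watermark_check.py - end-to-end check of the correlation-based detector
import hashlib

import numpy as np


def signal_for(secret: bytes, input_hash: str, dim: int, strength: float) -> np.ndarray:
    """Deterministic watermark signal derived from the secret and the input."""
    seed = int(hashlib.sha256(secret + input_hash.encode()).hexdigest()[:8], 16)
    return np.random.RandomState(seed).randn(dim) * strength


secret = b"ownership-key"
inputs = [f"query {i}" for i in range(50)]
outputs = []
for text in inputs:
    probs = np.ones(10) / 10                      # near-uniform base distribution
    h = hashlib.sha256(text.encode()).hexdigest()
    marked = np.clip(probs + signal_for(secret, h, 10, 0.01), 0, 1)
    outputs.append(marked / marked.sum())

# Detection: correlate each collected output against the expected signal
correlations = []
for text, out in zip(inputs, outputs):
    h = hashlib.sha256(text.encode()).hexdigest()
    expected = signal_for(secret, h, 10, 0.01)
    correlations.append(np.corrcoef(out - out.mean(), expected)[0, 1])

avg = float(np.mean(correlations))
print(f"average correlation across {len(correlations)} samples: {avg:.3f}")
```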
## Expected Behaviour
- Queries analysed in real time for extraction signatures (boundary probing, systematic coverage)
- Risk scores computed per API key and updated with each query
- API keys with risk score above 0.7 are automatically blocked pending review
- Output probabilities perturbed with calibrated noise (top-1 accuracy preserved, surrogate training degraded)
- Full probability distributions restricted to top-k unless enterprise tier
- Watermarks embedded in all outputs for post-hoc ownership verification
- Alerts fire within 5 minutes of sustained suspicious query patterns
## Trade-offs
| Control | Impact | Risk | Mitigation |
|---|---|---|---|
| Output perturbation | Adds noise to probability distributions | Downstream applications relying on exact probabilities may degrade | Tune noise scale per use case. Offer exact probabilities only to trusted enterprise customers. |
| Top-k truncation | Returns only top-k predictions | Users who need full distributions for calibration lose functionality | Provide full distributions through a separate, audited endpoint with additional authentication. |
| Boundary probing detection | Flags queries near decision boundaries | Legitimate active learning workflows probe boundaries intentionally | Allowlist known active learning pipelines. Review flagged API keys before blocking. |
| Rate limiting | Caps queries per time window | Legitimate high-volume users hit limits | Tier-based limits. Enterprise customers get higher limits with contractual anti-extraction clauses. |
## Failure Modes
| Failure | Symptom | Detection | Recovery |
|---|---|---|---|
| Low-and-slow extraction | Attacker stays below rate limits and detection thresholds | Model copy appears externally; detection scores stayed low | Lower detection thresholds. Add longer-window analysis (weekly, monthly). |
| False positive on power user | Legitimate high-volume user blocked | User reports access issues; support tickets | Review query patterns manually. Allowlist after verification. Adjust detection parameters. |
| Watermark removed by retraining | Attacker retrains surrogate, removing watermark | Watermark detection returns negative on suspected copy | Use multiple independent watermarking schemes. Combine with fingerprinting (unique responses to canary inputs). |
| Perturbation too aggressive | Top-1 predictions occasionally change | User reports inconsistent results; accuracy metrics drop | Reduce noise scale. Validate that top-1 accuracy is preserved on a test set before deploying noise parameters. |
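The fingerprinting mitigation in the last row can be sketched as follows (illustrative only: canary probes derived from a secret, recorded at deployment time, then replayed against a suspected copy; `canary_inputs` and `fingerprint_match_rate` are names invented here):

```python
# fingerprint_check.py - canary-input fingerprinting for stolen-model attribution
import hashlib


def canary_inputs(secret: str, count: int = 20) -> list:
    """Deterministic probe strings, unlikely to occur in natural traffic."""
    return [
        hashlib.sha256(f"{secret}:{i}".encode()).hexdigest()[:16]
        for i in range(count)
    ]


def fingerprint_match_rate(expected: dict, suspect_model) -> float:
    """Fraction of canaries where the suspect reproduces the recorded response."""
    hits = sum(1 for c, resp in expected.items() if suspect_model(c) == resp)
    return hits / len(expected)


# Recorded at deployment time: each canary's response from the production model
def production(canary: str) -> str:
    return canary[:4]  # stand-in for the real model's idiosyncratic behaviour


canaries = canary_inputs("attribution-secret")
expected = {c: production(c) for c in canaries}

stolen = production                      # a faithful copy reproduces the canaries
unrelated = lambda c: "none"             # an independent model does not
print(fingerprint_match_rate(expected, stolen))     # 1.0
print(fingerprint_match_rate(expected, unrelated))  # 0.0
```

Unlike the output watermark, this survives surrogate retraining only if the copy faithfully reproduces the production model's behaviour on the canaries, so it complements rather than replaces watermarking.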
## When to Consider a Managed Alternative
Model extraction defence requires ongoing monitoring, threshold tuning, and response to evolving attack techniques. The detection system itself needs regular updating.
- Cloudflare (#29) AI Gateway: Managed API gateway with rate limiting, request logging, and abuse detection for ML inference endpoints.
- Kong (#35): API gateway with advanced rate limiting plugins, consumer grouping, and request analytics for detecting anomalous usage patterns.
Premium content pack: Model extraction defence pack. Query pattern detector (Python), output perturbation library, watermarking toolkit, Kong rate limiting configurations, Prometheus alerting rules, and extraction attack simulation scripts for testing defences.