Membership Inference Defence: Preventing Attackers from Determining Training Data Inclusion

Problem

Membership inference attacks determine whether a specific data record was used to train a model. An attacker queries the model with a candidate record and analyses the output confidence to distinguish between training members (records the model memorised) and non-members. Models are more confident on data they have seen during training, and this signal leaks membership information.

This matters for privacy. If an attacker can confirm that a patient’s medical record was in the training data, that reveals the patient visited a specific hospital. If they can confirm a financial transaction was in the training set, that reveals the transaction occurred. Membership inference turns every deployed model into a potential privacy leak.

The attack requires only black-box access to the model’s API. The attacker does not need access to model weights, training code, or infrastructure. They need only query the model and observe output probabilities.
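To make the signal concrete, here is a minimal sketch of the simplest black-box attack: threshold the model's top-class confidence on the candidate record. The predict functions, probability values, and the 0.9 threshold below are invented for illustration, not taken from a specific attack implementation.

```python
import numpy as np

def confidence_attack(predict_proba, record, threshold: float = 0.9) -> bool:
    """Guess 'member' when top-class confidence exceeds a threshold
    calibrated on known non-members."""
    return float(np.max(predict_proba(record))) > threshold

# A model that memorised its training set is near-certain on members:
member_probs = lambda r: np.array([0.98, 0.01, 0.01])     # record seen in training
nonmember_probs = lambda r: np.array([0.55, 0.30, 0.15])  # unseen record

print(confidence_attack(member_probs, record=None))     # True: inferred member
print(confidence_attack(nonmember_probs, record=None))  # False: inferred non-member
```

Real attacks refine this with shadow models and per-class thresholds, but the leakage channel is the same: the confidence gap between members and non-members.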

Threat Model

  • Adversary: Any party with API access to the model. Could be a user, a competitor, a data subject exercising privacy rights, or a regulator testing compliance.
  • Objective: Determine whether a specific record was in the training data. Enumerate which individuals appear in the training set. Demonstrate that a model was trained on data without consent (regulatory action).
  • Blast radius: Privacy violation for affected individuals. Regulatory exposure under GDPR (right to erasure verification), HIPAA (medical data), or CCPA. Reputational damage if training data sourcing was questionable.

Configuration

Differential Privacy During Training

# dp_training.py - train with differential privacy using Opacus (PyTorch)
import torch
from torch.utils.data import DataLoader
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator

def train_with_dp(
    model: torch.nn.Module,
    train_loader: DataLoader,
    epochs: int = 10,
    target_epsilon: float = 8.0,
    target_delta: float = 1e-5,
    max_grad_norm: float = 1.0,
    learning_rate: float = 0.001,
):
    """
    Train a model with differential privacy guarantees.

    Args:
        target_epsilon: Privacy budget. Lower = more private, less accurate.
                        Recommended: 1-10 for strong privacy, 10-50 for moderate.
        target_delta: Probability of privacy guarantee failure. Should be < 1/N.
        max_grad_norm: Per-sample gradient clipping norm.
    """
    # Validate model is compatible with DP (no batch norm, etc.)
    model = ModuleValidator.fix(model)
    errors = ModuleValidator.validate(model, strict=False)
    if errors:
        raise ValueError(f"Model incompatible with DP training: {errors}")

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Attach the privacy engine
    privacy_engine = PrivacyEngine()
    model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_loader,
        epochs=epochs,
        target_epsilon=target_epsilon,
        target_delta=target_delta,
        max_grad_norm=max_grad_norm,
    )

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            inputs, labels = batch
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        epsilon = privacy_engine.get_epsilon(delta=target_delta)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss:.4f} - Epsilon: {epsilon:.2f}")

    final_epsilon = privacy_engine.get_epsilon(delta=target_delta)
    print(f"Training complete. Final (epsilon, delta): ({final_epsilon:.2f}, {target_delta})")
    return model, final_epsilon
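The `target_delta` docstring above recommends a value below 1/N. A common rule of thumb is to set it an order of magnitude smaller than 1/N; this tiny helper (name and safety factor are illustrative, not part of Opacus) encodes that:

```python
def recommended_delta(n_training_records: int, safety_factor: int = 10) -> float:
    """Rule of thumb: delta should be well below 1/N, where N is the
    number of training records, so the privacy guarantee is unlikely
    to fail for any individual. Helper name and factor are illustrative."""
    return 1.0 / (n_training_records * safety_factor)

# 100k training records -> delta of 1e-6, an order below the 1/N = 1e-5 bound
print(recommended_delta(100_000))  # 1e-06
```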

Output Calibration and Confidence Restriction

# confidence_restrictor.py - restrict output confidence to reduce membership signal
import numpy as np
from typing import Optional

class ConfidenceRestrictor:
    """
    Reduce the precision of model output confidences.
    Membership inference exploits the gap between confidence on members
    vs non-members. Restricting confidence reduces this signal.
    """

    def __init__(
        self,
        temperature: float = 2.0,
        precision: int = 2,
        max_confidence: Optional[float] = 0.95,
    ):
        self.temperature = temperature
        self.precision = precision
        self.max_confidence = max_confidence

    def apply_temperature(self, logits: np.ndarray) -> np.ndarray:
        """Apply temperature scaling to reduce confidence sharpness."""
        scaled = logits / self.temperature
        exp_scaled = np.exp(scaled - np.max(scaled))
        return exp_scaled / exp_scaled.sum()

    def round_probabilities(self, probs: np.ndarray) -> np.ndarray:
        """Round probabilities to reduce precision."""
        rounded = np.round(probs, self.precision)
        total = rounded.sum()
        if total == 0:  # every class rounded to zero (many near-uniform classes)
            return np.full_like(probs, 1.0 / len(probs))
        return rounded / total  # re-normalise

    def cap_confidence(self, probs: np.ndarray) -> np.ndarray:
        """Cap maximum confidence to prevent overconfident outputs.

        Note: a naive min-then-renormalise pushes the top class back
        above the cap, so the excess is redistributed instead.
        """
        if self.max_confidence is None or probs.max() <= self.max_confidence:
            return probs
        top = int(np.argmax(probs))
        capped = probs.copy()
        capped[top] = self.max_confidence
        others = np.arange(len(probs)) != top
        if probs[others].sum() > 0:
            capped[others] = probs[others] * (1.0 - self.max_confidence) / probs[others].sum()
        return capped

    def process(self, logits: np.ndarray) -> np.ndarray:
        """Full pipeline: temperature -> cap -> round."""
        probs = self.apply_temperature(logits)
        probs = self.cap_confidence(probs)
        probs = self.round_probabilities(probs)
        return probs
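To see why temperature scaling weakens the attack, this hedged snippet (logit values invented) compares the member/non-member confidence gap at T=1 and T=2; as the temperature grows, both distributions move toward uniform and the gap shrinks toward zero:

```python
import numpy as np

def softmax(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    scaled = logits / temperature
    exp = np.exp(scaled - np.max(scaled))  # shift for numerical stability
    return exp / exp.sum()

member_logits = np.array([4.0, 1.0, 0.5])     # memorised record: sharper logits
nonmember_logits = np.array([2.0, 1.5, 1.0])  # unseen record: flatter logits

for t in (1.0, 2.0):
    gap = softmax(member_logits, t)[0] - softmax(nonmember_logits, t)[0]
    print(f"T={t}: member/non-member confidence gap = {gap:.3f}")
```

With these example logits the gap drops from roughly 0.42 at T=1 to roughly 0.30 at T=2; the exact effect depends on the logit distributions, which is why the trade-off table recommends tuning the temperature per application.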

Monitoring for Membership Inference Query Patterns

# mi_detector.py - detect membership inference attack patterns
import time
from collections import Counter, defaultdict
from typing import List

class MembershipInferenceDetector:
    """
    Detect query patterns consistent with membership inference attacks.
    Key signals:
    - Same record queried repeatedly (statistical averaging)
    - Queries for records with slight perturbations (shadow model training)
    - Requests that consistently ask for full probability distributions
    - Query distribution matches known dataset structure
    """

    def __init__(self, window_seconds: int = 3600):
        self.window = window_seconds
        self.query_hashes = defaultdict(list)  # api_key -> [(input_hash, timestamp)]

    def record_query(self, api_key: str, input_hash: str):
        now = time.time()
        self.query_hashes[api_key].append((input_hash, now))

        # Prune entries that have aged out of the sliding window
        cutoff = now - self.window
        self.query_hashes[api_key] = [
            (h, t) for h, t in self.query_hashes[api_key] if t > cutoff
        ]

    def check_repeated_queries(self, api_key: str) -> dict:
        """Detect if the same inputs are queried repeatedly within the window."""
        # Count repeats from the pruned window so stale queries do not
        # inflate the signal indefinitely.
        counts = Counter(h for h, _ in self.query_hashes[api_key])
        repeats = {h: c for h, c in counts.items() if c > 3}
        return {
            "repeated_inputs": len(repeats),
            "max_repeats": max(repeats.values()) if repeats else 0,
            "suspicious": len(repeats) > 10,
        }

    def check_perturbation_pattern(self, api_key: str, input_hashes: List[str]) -> dict:
        """
        Detect if queries form perturbation clusters.
        Membership inference often queries record X, then X with small changes.
        """
        # Simplified: check for high diversity of unique inputs
        unique = len(set(input_hashes))
        total = len(input_hashes)
        diversity = unique / total if total > 0 else 0
        return {
            "unique_inputs": unique,
            "total_queries": total,
            "diversity_ratio": round(diversity, 3),
            "suspicious": diversity > 0.9 and total > 100,
        }

    def get_risk_assessment(self, api_key: str) -> dict:
        hashes = [h for h, _ in self.query_hashes[api_key]]
        repeat_check = self.check_repeated_queries(api_key)
        perturbation_check = self.check_perturbation_pattern(api_key, hashes)

        risk_score = 0.0
        if repeat_check["suspicious"]:
            risk_score += 0.5
        if perturbation_check["suspicious"]:
            risk_score += 0.5

        return {
            "api_key": api_key,
            "risk_score": risk_score,
            "repeat_analysis": repeat_check,
            "perturbation_analysis": perturbation_check,
            "action": "block" if risk_score > 0.7 else "monitor" if risk_score > 0.3 else "allow",
        }
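`record_query` expects a stable `input_hash`. One way to produce it (the helper name is invented for illustration) is a SHA-256 over a canonical JSON encoding, so semantically identical payloads count as repeats regardless of key order:

```python
import hashlib
import json

def canonical_input_hash(payload: dict) -> str:
    """Stable SHA-256 over a canonical JSON encoding, so the detector
    counts semantically identical requests as repeated queries."""
    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

# Key order does not defeat repeat detection:
a = canonical_input_hash({"age": 54, "diagnosis": "I21.9"})
b = canonical_input_hash({"diagnosis": "I21.9", "age": 54})
print(a == b)  # True
```

Without canonicalisation, an attacker could dodge the repeat counter simply by reordering JSON keys between queries.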

Prometheus Alerting for Membership Inference

# prometheus-membership-inference.yaml
groups:
  - name: membership-inference
    interval: 1m
    rules:
      # Track repeated queries for the same input
      - alert: RepeatedInputQueries
        expr: >
          max by (api_key) (
            inference_input_repeat_count
          ) > 10
        for: 10m
        labels:
          severity: warning
        annotations:
          summary: "API key {{ $labels.api_key }} queried the same input {{ $value }} times"
          description: >
            Repeated querying of identical inputs is a signature of membership inference
            attacks (statistical averaging to reduce noise). Investigate this API key.

      # Alert on high query diversity (systematic probing)
      - alert: HighQueryDiversity
        expr: >
          inference_unique_inputs_ratio > 0.95
          and
          increase(inference_requests_total[1h]) > 200
        for: 30m
        labels:
          severity: warning
        annotations:
          summary: "API key {{ $labels.api_key }} has a unique-input ratio of {{ $value }} alongside >200 queries in the last hour"
          description: >
            High input diversity combined with high volume suggests systematic dataset
            probing rather than normal application usage.

      # Alert on membership inference risk score
      - alert: MembershipInferenceRiskHigh
        expr: membership_inference_risk_score > 0.7
        for: 5m
        labels:
          severity: critical
        annotations:
          summary: "High membership inference risk for {{ $labels.api_key }}"

Expected Behaviour

  • Models trained with differential privacy (epsilon 1-10) resist membership inference at a modest accuracy cost (a few percentage points; see Trade-offs)
  • Output confidences are temperature-scaled, capped, and rounded before returning to API consumers
  • Repeated queries for identical inputs are detected and flagged within 10 minutes
  • High query diversity patterns trigger monitoring alerts within 30 minutes
  • Risk scores above 0.7 result in API key suspension pending review
  • Privacy guarantees are mathematically provable via the epsilon-delta framework
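For reference, the epsilon-delta guarantee cited above is the standard definition of (ε, δ)-differential privacy: for a training mechanism M, any two datasets D and D' differing in a single record, and any set of possible outputs S,

```latex
\Pr[M(D) \in S] \;\le\; e^{\varepsilon} \, \Pr[M(D') \in S] + \delta
```

Smaller epsilon means the trained model reveals less about whether any single record was present; delta bounds the probability that the guarantee fails outright, which is why it should be far below 1/N.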

Trade-offs

| Control | Impact | Risk | Mitigation |
| --- | --- | --- | --- |
| Differential privacy (epsilon=8) | 2-5% accuracy reduction depending on model and dataset | Model performance may be unacceptable for high-stakes applications | Tune epsilon per use case. Use larger training datasets to offset accuracy loss. |
| Temperature scaling (T=2.0) | Flattens probability distributions | Applications relying on calibrated probabilities lose calibration | Provide calibrated outputs only through authenticated, audited endpoints. |
| Confidence capping (max 0.95) | Reduces maximum reported confidence | Users cannot distinguish between 95% and 99.9% confident predictions | Accept this as a privacy cost. Document the cap in API documentation. |
| Repeated query detection | Flags legitimate retries and idempotent requests | False positives on applications with retry logic | Allowlist known retry patterns. Exempt health check endpoints. |

Failure Modes

| Failure | Symptom | Detection | Recovery |
| --- | --- | --- | --- |
| Epsilon too large | Membership inference still succeeds despite DP training | Red team audit demonstrates >60% membership inference accuracy | Retrain with lower epsilon (stronger privacy). Accept accuracy trade-off. |
| Temperature too high | Model outputs are uniformly uncertain, useless for applications | User reports that predictions are never confident; utility metrics drop | Reduce temperature. Find the minimum temperature that blocks inference while preserving utility. |
| Detector false positive | Legitimate application blocked for “suspicious” queries | Support tickets from blocked users; business impact | Review blocked API keys within 1 hour SLA. Tune detection thresholds. |
| Adaptive attacker | Attacker varies query timing and distribution to avoid detection | Membership inference succeeds; detection scores stay low | Add longer-window analysis. Combine multiple detection signals. Deploy honeypot records. |
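The red-team audit mentioned above can be scored with a simple advantage metric over planted canary records versus held-out controls: true-positive rate minus false-positive rate. This sketch is illustrative (the confidence values are invented); on a balanced canary/control set, attack accuracy is (1 + advantage) / 2, so an advantage near zero corresponds to the 50% chance level a defence should aim for.

```python
def mi_advantage(member_confidences, nonmember_confidences, threshold: float) -> float:
    """Attack advantage = TPR on known training members minus FPR on
    held-out non-members; a value near 0 means the defence removed
    the membership signal at this threshold."""
    tpr = sum(c > threshold for c in member_confidences) / len(member_confidences)
    fpr = sum(c > threshold for c in nonmember_confidences) / len(nonmember_confidences)
    return tpr - fpr

# Canary records planted in the training set vs held-out controls (invented):
members = [0.97, 0.99, 0.96, 0.88, 0.95]
nonmembers = [0.70, 0.91, 0.65, 0.80, 0.60]
print(mi_advantage(members, nonmembers, threshold=0.9))
```

An advantage of 0.6 here corresponds to 80% attack accuracy, well past the >60% audit threshold in the table, so this model would fail the red-team check.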

When to Consider a Managed Alternative

Differential privacy training and membership inference monitoring require ML privacy expertise that most teams lack. The trade-off tuning between privacy and utility is application-specific and ongoing.

  • Grafana Cloud (#108): Long-term metric storage for query pattern analysis. ML-powered anomaly detection on API usage patterns.
  • Cloudflare (#29) AI Gateway: Managed API gateway with abuse detection, rate limiting, and request analytics for inference endpoints.

Premium content pack: Membership inference defence pack. Opacus training scripts with tuned hyperparameters, confidence restriction middleware (Python), membership inference detection service, Prometheus alerting rules, and a membership inference red team test suite for validating defences.