x-algo-ml

X Algorithm ML Architecture

The X recommendation system uses Phoenix, a transformer-based ML system for predicting user engagement. It operates in two stages: retrieval and ranking.

Two-Stage Pipeline

┌─────────────────────────────────────────────────────────────────────────────────┐
│                           RECOMMENDATION PIPELINE                               │
├─────────────────────────────────────────────────────────────────────────────────┤
│                                                                                 │
│   ┌──────────┐     ┌─────────────────────┐     ┌─────────────────────┐          │
│   │          │     │                     │     │                     │          │
│   │   User   │────▶│   STAGE 1:          │────▶│   STAGE 2:          │────▶ Feed│
│   │ Request  │     │   RETRIEVAL         │     │   RANKING           │          │
│   │          │     │   (Two-Tower)       │     │   (Transformer)     │          │
│   └──────────┘     │                     │     │                     │          │
│                    │   Millions → 1000s  │     │   1000s → Ranked    │          │
│                    └─────────────────────┘     └─────────────────────┘          │
│                                                                                 │
└─────────────────────────────────────────────────────────────────────────────────┘

Stage 1: Retrieval (Two-Tower Model)

Efficiently narrows millions of candidates to thousands using approximate nearest neighbor search.

Architecture

  1. User Tower: Encodes user features + engagement history → normalized embedding [B, D]
  2. Candidate Tower: Pre-computed embeddings for all posts in corpus → [N, D]
  3. Similarity: Dot product between the user embedding and each candidate embedding; the top-K posts by score move on to ranking

User Tower          Candidate Tower
    │                     │
    ▼                     ▼
[B, D] user emb    [N, D] all posts
    │                     │
    └───── dot product ───┘
         Top-K by similarity
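A minimal NumPy sketch of this stage (NumPy stands in for the JAX used elsewhere in this document). It assumes both towers emit L2-normalized embeddings, so the dot product equals cosine similarity. It also uses an exact top-K where production would use an approximate nearest-neighbor index. `l2_normalize` and `retrieve_top_k` are hypothetical helpers.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale each row to (approximately) unit length so dot products act as cosine similarities."""
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)


def retrieve_top_k(user_emb: np.ndarray, corpus_emb: np.ndarray, k: int):
    """Exact top-K retrieval.

    user_emb:   [B, D] user-tower output
    corpus_emb: [N, D] precomputed candidate-tower embeddings
    Returns (indices, scores), each [B, k], sorted by descending score.
    """
    scores = user_emb @ corpus_emb.T                        # [B, N]: score against every post
    top = np.argpartition(-scores, k - 1, axis=-1)[:, :k]   # the k best per row, unordered
    top_scores = np.take_along_axis(scores, top, axis=-1)
    order = np.argsort(-top_scores, axis=-1)                # sort only the k winners
    return (np.take_along_axis(top, order, axis=-1),
            np.take_along_axis(top_scores, order, axis=-1))


rng = np.random.default_rng(0)
users = l2_normalize(rng.normal(size=(2, 16)))      # B=2, D=16
posts = l2_normalize(rng.normal(size=(1000, 16)))   # N=1000
idx, sc = retrieve_top_k(users, posts, k=5)
print(idx.shape, sc.shape)  # (2, 5) (2, 5)
```

`np.argpartition` finds the K best in linear time, and only those K are sorted. This matters when N is in the millions. An ANN index goes further by never materializing the full [B, N] score matrix.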

Stage 2: Ranking (Transformer with Candidate Isolation)

Scores the retrieved candidates using a transformer that predicts multiple engagement actions.

Model Configuration

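# phoenix/recsys_model.py
from __future__ import annotations  # annotations stay unevaluated, so the config types can live elsewhere
from dataclasses import dataclass

@dataclass
class PhoenixModelConfig:
    # Fields without defaults must come before fields with defaults
    model: TransformerConfig              # Grok-1 based transformer
    emb_size: int                         # Embedding dimension D
    num_actions: int                      # 18 action types
    hash_config: HashConfig               # Hash embedding config
    history_seq_len: int = 128            # User history length S
    candidate_seq_len: int = 32           # Candidates scored per example C
    product_surface_vocab_size: int = 16  # Number of surfaces where a post can be seen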

Input Structure

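from typing import NamedTuple
from jax.typing import ArrayLike

# B = batch size, S = history length, C = number of candidates
class RecsysBatch(NamedTuple):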
    # User identification
    user_hashes: ArrayLike               # [B, num_user_hashes]

    # User engagement history
    history_post_hashes: ArrayLike       # [B, S, num_item_hashes]
    history_author_hashes: ArrayLike     # [B, S, num_author_hashes]
    history_actions: ArrayLike           # [B, S, num_actions]
    history_product_surface: ArrayLike   # [B, S]

    # Candidates to score
    candidate_post_hashes: ArrayLike     # [B, C, num_item_hashes]
    candidate_author_hashes: ArrayLike   # [B, C, num_author_hashes]
    candidate_product_surface: ArrayLike # [B, C]
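To make the shapes concrete, here is a hypothetical `dummy_batch` helper. It fills every field with zeros at the default sizes from `PhoenixModelConfig` (S = 128, C = 32), `HashConfig` (2 hashes per entity), and the 18 action types. NumPy arrays stand in for `ArrayLike`, and the integer dtype is an assumption.

```python
from typing import NamedTuple

import numpy as np


class RecsysBatch(NamedTuple):
    # Same fields as above; np.ndarray stands in for ArrayLike.
    user_hashes: np.ndarray
    history_post_hashes: np.ndarray
    history_author_hashes: np.ndarray
    history_actions: np.ndarray
    history_product_surface: np.ndarray
    candidate_post_hashes: np.ndarray
    candidate_author_hashes: np.ndarray
    candidate_product_surface: np.ndarray


def dummy_batch(B=2, S=128, C=32, num_hashes=2, num_actions=18) -> RecsysBatch:
    """All-zero batch at the documented default sizes (hypothetical helper)."""
    def zeros(*shape):
        return np.zeros(shape, dtype=np.int32)

    return RecsysBatch(
        user_hashes=zeros(B, num_hashes),
        history_post_hashes=zeros(B, S, num_hashes),
        history_author_hashes=zeros(B, S, num_hashes),
        history_actions=zeros(B, S, num_actions),  # multi-hot: 1 where the action was taken
        history_product_surface=zeros(B, S),       # one surface ID per history item
        candidate_post_hashes=zeros(B, C, num_hashes),
        candidate_author_hashes=zeros(B, C, num_hashes),
        candidate_product_surface=zeros(B, C),
    )


batch = dummy_batch()
for name, arr in batch._asdict().items():
    print(f"{name:26s} {arr.shape}")
```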

Hash-Based Embeddings

Each ID is hashed with several hash functions, and each hash value indexes a row of an embedding table:

from dataclasses import dataclass

@dataclass
class HashConfig:
    num_user_hashes: int = 2      # Hash user ID 2 ways
    num_item_hashes: int = 2      # Hash post ID 2 ways
    num_author_hashes: int = 2    # Hash author ID 2 ways

Why hashes?

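  • Fixed memory: Table size is set by the number of hash buckets, not by the number of users or posts
  • Handles new entities: Any ID, including one never seen in training, maps to some embedding
  • Collision averaging: Two IDs that collide under one hash rarely collide under all of them, so their combined embeddings usually still differ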
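A minimal sketch of multi-hash lookup. The hashing scheme (salted BLAKE2b from `hashlib`), the single shared table, the table size, and the helper names `hash_ids` and `lookup` are all assumptions for illustration. This document does not say which hash functions Phoenix uses or whether each hash has its own table. Each of the `num_hashes` functions selects one row, yielding one embedding per hash for a reduce step to combine later.

```python
import hashlib

import numpy as np


def hash_ids(entity_id: int, num_hashes: int, num_buckets: int) -> list[int]:
    """Map one ID to `num_hashes` bucket indices using differently salted hashes."""
    buckets = []
    for seed in range(num_hashes):
        digest = hashlib.blake2b(
            str(entity_id).encode(), digest_size=8, salt=seed.to_bytes(16, "little")
        ).digest()
        buckets.append(int.from_bytes(digest, "little") % num_buckets)
    return buckets


def lookup(entity_id: int, table: np.ndarray, num_hashes: int) -> np.ndarray:
    """Return [num_hashes, D]: one table row per hash function."""
    rows = hash_ids(entity_id, num_hashes, num_buckets=table.shape[0])
    return table[rows]


rng = np.random.default_rng(0)
table = rng.normal(size=(1_000, 8))             # size fixed up front, independent of ID count
emb = lookup(123_456_789, table, num_hashes=2)  # works for any ID, seen or not
print(emb.shape)  # (2, 8)
```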

Embedding Combination

Each entity type has a "reduce" function that combines hash embeddings:

# User: Concatenate hash embeddings → project to D
def block_user_reduce(...):
    # [B, num_user_hashes, D] → [B, 1, num_user_hashes * D] → [B, 1, D]
    user_embedding = user_embeddings.reshape((B, 1, num_user_hashes * D))
    user_embedding = jnp.dot(user_embedding, proj_mat_1)  # Project down
    return user_embedding, user_padding_mask

# History: Combine post + author + actions + product_surface
def block_history_reduce(...):
    # Concatenate all features, project to D
    post_author_embedding = jnp.concatenate([
        history_post_embeddings_reshaped,
        history_author_embeddings_reshaped,
        history_actions_embeddings,
        history_product_surface_embeddings,
    ], axis=-1)
    history_embedding = jnp.dot(post_author_embedding, proj_mat_3)
    return history_embedding, history_padding_mask
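Both reduce functions follow a concatenate-then-project pattern, which can be run end to end in NumPy. This sketch assumes `num_user_hashes = 2` and D = 8, and replaces the learned `proj_mat_1` and `proj_mat_3` with random matrices. For the history features, the per-group widths (post and author at `H * D` after reshaping, actions and surface at D) are also assumptions. `user_reduce` and `history_reduce` are hypothetical names mirroring `block_user_reduce` and `block_history_reduce`.

```python
import numpy as np


def user_reduce(user_embeddings: np.ndarray, proj: np.ndarray) -> np.ndarray:
    """[B, num_user_hashes, D] -> [B, 1, D]: concatenate hash embeddings, then project."""
    B, H, D = user_embeddings.shape
    flat = user_embeddings.reshape((B, 1, H * D))  # row-major reshape == concatenating the H vectors
    return flat @ proj                             # proj: [H * D, D]


def history_reduce(post, author, actions, surface, proj):
    """Concatenate per-position feature groups on the last axis, then project to D."""
    combined = np.concatenate([post, author, actions, surface], axis=-1)  # [B, S, F]
    return combined @ proj                                                # [B, S, D]


B, S, H, D = 4, 5, 2, 8
rng = np.random.default_rng(0)

user_embeddings = rng.normal(size=(B, H, D))
proj_mat_1 = rng.normal(size=(H * D, D)) / np.sqrt(H * D)  # stand-in for learned weights
user_out = user_reduce(user_embeddings, proj_mat_1)          # (4, 1, 8)

# Assumed widths: post/author keep H*D after reshaping, actions/surface are already D.
post = rng.normal(size=(B, S, H * D))
author = rng.normal(size=(B, S, H * D))
actions = rng.normal(size=(B, S, D))
surface = rng.normal(size=(B, S, D))
proj_mat_3 = rng.normal(size=(2 * H * D + 2 * D, D))
history_out = history_reduce(post, author, actions, surface, proj_mat_3)  # (4, 5, 8)
```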

Transformer Input

The final transformer input concatenates the user, history, and candidate embeddings along the sequence axis:

[User (1)] + [History (S)] + [Candidates (C)]
     │              │               │
     ▼              ▼               ▼
 [B, 1, D]    [B, S, D]       [B, C, D]
         ╲       │       ╱
          ╲      │      ╱
            [B, 1+S+C, D]
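With the default config (S = 128, C = 32), the sequence has 1 + 128 + 32 = 161 positions, and candidates start at offset 1 + S = 129. A NumPy sketch of the concatenation follows. `build_sequence` is a hypothetical helper whose returned offset plays the role of the `candidate_start` value that `build_inputs` returns in the forward pass.

```python
import numpy as np


def build_sequence(user, history, candidates):
    """Concatenate [B,1,D], [B,S,D], [B,C,D] along the sequence axis.

    Returns the [B, 1+S+C, D] input and the index where candidates begin.
    """
    seq = np.concatenate([user, history, candidates], axis=1)
    candidate_start = user.shape[1] + history.shape[1]
    return seq, candidate_start


B, S, C, D = 2, 128, 32, 8
seq, start = build_sequence(np.zeros((B, 1, D)), np.zeros((B, S, D)), np.ones((B, C, D)))
print(seq.shape, start)  # (2, 161, 8) 129
```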

Attention Masking: Candidate Isolation

Critical design choice: candidates cannot attend to each other. Each candidate attends only to the user, the history, and itself.

                    ATTENTION MASK
         Keys (what we attend TO)
         ─────────────────────────────────────────────▶

         │ User │    History (S)    │   Candidates (C)    │
    ┌────┼──────┼───────────────────┼─────────────────────┤
 Q  │ U  │  ✓   │  ✓   ✓   ✓   ✓   │  ✗   ✗   ✗   ✗      │
 u  ├────┼──────┼───────────────────┼─────────────────────┤
 e  │ H  │  ✓   │  ✓   ✓   ✓   ✓   │  ✗   ✗   ✗   ✗      │
 r  │ i  │  ✓   │  ✓   ✓   ✓   ✓   │  ✗   ✗   ✗   ✗      │
 i  │ s  │  ✓   │  ✓   ✓   ✓   ✓   │  ✗   ✗   ✗   ✗      │
 e  │ t  │  ✓   │  ✓   ✓   ✓   ✓   │  ✗   ✗   ✗   ✗      │
 s  ├────┼──────┼───────────────────┼─────────────────────┤
    │ C  │  ✓   │  ✓   ✓   ✓   ✓   │  ✓   ✗   ✗   ✗      │
 │  │ a  │  ✓   │  ✓   ✓   ✓   ✓   │  ✗   ✓   ✗   ✗      │
 │  │ n  │  ✓   │  ✓   ✓   ✓   ✓   │  ✗   ✗   ✓   ✗      │
 ▼  │ d  │  ✓   │  ✓   ✓   ✓   ✓   │  ✗   ✗   ✗   ✓      │
    └────┴──────┴───────────────────┴─────────────────────┘

    ✓ = Can attend          ✗ = Cannot attend (diagonal only for candidates)

Why candidate isolation?

  • Score for post A shouldn't depend on whether post B is in the batch
  • Ensures consistent scoring regardless of batch composition
  • Enables parallel scoring of candidates
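The mask in the diagram can be built as a boolean matrix. A single attention head is enough to check the property claimed above: perturbing one candidate leaves every other candidate's output unchanged. This NumPy sketch uses one unparameterized head (queries, keys, and values are the raw embeddings), whereas the real model uses multi-head Grok-1 attention. `candidate_isolation_mask` and `masked_attention` are hypothetical helpers.

```python
import numpy as np


def candidate_isolation_mask(num_context: int, num_candidates: int) -> np.ndarray:
    """[T, T] boolean mask, True where a query row may attend to a key column.

    num_context = 1 + S (user + history). Every position may attend to the
    context; a candidate may additionally attend to itself only.
    """
    T = num_context + num_candidates
    mask = np.zeros((T, T), dtype=bool)
    mask[:, :num_context] = True        # everyone sees user + history
    idx = np.arange(num_context, T)
    mask[idx, idx] = True               # candidate diagonal
    return mask


def masked_attention(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Single-head softmax attention over x: [T, D]; disallowed scores become -inf."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)                      # exp(-inf) = 0 for masked keys
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


mask = candidate_isolation_mask(num_context=5, num_candidates=4)  # 1 user + 4 history, 4 candidates
rng = np.random.default_rng(0)
x = rng.normal(size=(9, 8))
out = masked_attention(x, mask)

x2 = x.copy()
x2[5] += 10.0                                 # change candidate 0 only
out2 = masked_attention(x2, mask)
print(np.allclose(out[6:], out2[6:]))         # True: other candidates are unaffected
```

Because user and history rows never attend to candidates, the context representations are also unchanged, so the same argument holds layer after layer in a deeper stack.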

Transformer Forward Pass

def __call__(self, batch, recsys_embeddings) -> RecsysModelOutput:
    # 1. Build combined embeddings
    embeddings, padding_mask, candidate_start = self.build_inputs(batch, recsys_embeddings)

    # 2. Pass through transformer (with candidate isolation mask)
    model_output = self.model(
        embeddings,
        padding_mask,
        candidate_start_offset=candidate_start,  # For attention masking
    )

    # 3. Extract candidate outputs
    out_embeddings = layer_norm(model_output.embeddings)
    candidate_embeddings = out_embeddings[:, candidate_start:, :]

    # 4. Project to action logits
    logits = jnp.dot(candidate_embeddings, unembeddings)  # unembeddings: [D, num_actions]
    # Shape: [B, num_candidates, num_actions]

    return RecsysModelOutput(logits=logits)
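Steps 3 and 4 reduce to a normalization, a slice, and a matrix product. This NumPy sketch assumes a plain layer norm with no learned scale or bias and a random `[D, num_actions]` unembedding matrix; the real model learns these weights. `candidate_logits` is a hypothetical helper.

```python
import numpy as np


def layer_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize the last axis to zero mean and unit variance (no learned scale/bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def candidate_logits(hidden: np.ndarray, candidate_start: int, unembeddings: np.ndarray) -> np.ndarray:
    """hidden: [B, 1+S+C, D] transformer output -> [B, C, num_actions]."""
    out = layer_norm(hidden)
    cand = out[:, candidate_start:, :]  # keep only candidate positions
    return cand @ unembeddings          # [B, C, D] @ [D, A] -> [B, C, A]


B, S, C, D, A = 2, 128, 32, 16, 18
rng = np.random.default_rng(0)
hidden = rng.normal(size=(B, 1 + S + C, D))
unembeddings = rng.normal(size=(D, A))
logits = candidate_logits(hidden, 1 + S, unembeddings)
print(logits.shape)  # (2, 32, 18)
```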

Output: Multi-Action Prediction

Output Shape: [B, num_candidates, num_actions]
    ┌─────────────────────────────────────────────┐
    │ Like │ Reply │ Retweet │ Quote │ ... (18)   │
    └─────────────────────────────────────────────┘

Each output is a log-probability. Convert to probability:

probability = exp(log_prob)
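Applied to the full output tensor, the conversion is elementwise. This NumPy sketch uses random non-positive values as a stand-in for model output. It assumes index 0 is "Like", following the diagram's column order; the full ordering of the 18 actions is not given here.

```python
import numpy as np

B, C, A = 1, 3, 18
rng = np.random.default_rng(0)
# Stand-in model output: any non-positive values are valid log-probabilities.
log_probs = -rng.exponential(scale=3.0, size=(B, C, A))  # [B, C, num_actions]

probs = np.exp(log_probs)  # elementwise, one probability per (candidate, action)
LIKE = 0                   # assumed index of "Like", following the diagram's column order
print(probs[0, :, LIKE])   # P(like) for each of the 3 candidates
```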

Action Embeddings

History actions are encoded as signed vectors:

def _get_action_embeddings(self, actions):
    # actions: [B, S, num_actions] multi-hot vector
    actions_signed = (2 * actions - 1)  # 0→-1, 1→+1
    action_emb = jnp.dot(actions_signed, action_projection)
    return action_emb

This encodes "did action" (+1) vs "didn't do action" (-1) for each action type.
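A NumPy sketch of the signed encoding, with a random `[num_actions, D]` matrix standing in for the learned `action_projection`. It shows why the sign matters. With plain 0/1 multi-hot input, a history item with no actions projects to the zero vector. With the signed encoding, every absent action contributes the negative of its projection row, so "did not engage" still carries signal.

```python
import numpy as np


def action_embeddings(actions: np.ndarray, action_projection: np.ndarray) -> np.ndarray:
    """actions: [B, S, A] multi-hot in {0, 1}; action_projection: [A, D] -> [B, S, D]."""
    signed = 2 * actions - 1  # 0 -> -1, 1 -> +1
    return signed @ action_projection


A, D = 4, 3
rng = np.random.default_rng(0)
action_projection = rng.normal(size=(A, D))  # stand-in for the learned projection

took = np.array([[[1, 0, 0, 1]]])                 # B=1, S=1: actions 0 and 3 taken
emb = action_embeddings(took, action_projection)  # row0 - row1 - row2 + row3

none = np.zeros((1, 1, A), dtype=int)                # a history item with no actions at all
plain = none @ action_projection                     # 0/1 encoding: the zero vector
signed = action_embeddings(none, action_projection)  # signed: minus the sum of all rows
```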

Product Surface Embeddings

Product surface records where the user saw or engaged with a post (home feed, search, notifications, etc.) and is embedded with a standard lookup table:

def _single_hot_to_embeddings(self, input, vocab_size, emb_size, name):
    # Standard embedding lookup table
    embedding_table = hk.get_parameter(name, [vocab_size, emb_size])
    input_one_hot = jax.nn.one_hot(input, vocab_size)
    return jnp.dot(input_one_hot, embedding_table)
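Multiplying a one-hot matrix by the table selects rows, so `_single_hot_to_embeddings` is equivalent to integer indexing into the table. This NumPy check uses the config's `product_surface_vocab_size` of 16, an arbitrary embedding size of 8, and arbitrary surface IDs. Which surface each ID denotes is not specified here.

```python
import numpy as np

vocab_size, emb_size = 16, 8
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(vocab_size, emb_size))

surface_ids = np.array([[0, 3, 3, 15]])    # [B=1, S=4] integer surface IDs
one_hot = np.eye(vocab_size)[surface_ids]  # [1, 4, 16], like jax.nn.one_hot for in-range IDs
via_matmul = one_hot @ embedding_table     # what the function above computes
via_index = embedding_table[surface_ids]   # direct row lookup

print(via_matmul.shape, np.allclose(via_matmul, via_index))  # (1, 4, 8) True
```

With a vocabulary this small, the extra matmul in the one-hot form is cheap.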

Model Heritage

The sample transformer implementation is ported from xAI's open-source Grok-1 release. Its core architecture comes from Grok-1, adapted for recommendation with custom input embeddings and an attention mask that enforces candidate isolation.

Related Skills

  • /x-algo-engagement - The 18 action types the model predicts
  • /x-algo-scoring - How predictions become weighted scores
  • /x-algo-pipeline - Where ML fits in the full system