# Inference Optimization Skill

Making AI inference faster and cheaper.
## Performance Metrics
```python
from dataclasses import dataclass

@dataclass
class InferenceMetrics:
    ttft: float        # Time to First Token (seconds)
    tpot: float        # Time Per Output Token (seconds)
    throughput: float  # Tokens per second
    latency: float     # Total request time (seconds)
```
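
A minimal sketch of how these numbers can be measured from any iterator that yields tokens as they are generated (the `stream` argument is a stand-in for your serving stack's streaming API):

```python
import time

def measure_stream(stream):
    """Derive TTFT/TPOT/throughput/latency from a token stream."""
    start = time.perf_counter()
    arrivals = []
    for _ in stream:                      # consume tokens as they arrive
        arrivals.append(time.perf_counter())
    ttft = arrivals[0] - start
    total = arrivals[-1] - start
    n = len(arrivals)
    return InferenceMetrics(
        ttft=ttft,
        tpot=(total - ttft) / max(n - 1, 1),  # mean time per token after the first
        throughput=n / total,
        latency=total,
    )
```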
## Model Optimization

### Quantization
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_name = "meta-llama/Llama-2-7b-hf"  # example checkpoint

# 8-bit quantization (bitsandbytes)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# 4-bit quantization (bitsandbytes)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    device_map="auto",
)

# GPTQ (calibration-based 4-bit, usually better quality than naive 4-bit)
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "TheBloke/Llama-2-7B-GPTQ"
)

# AWQ (activation-aware quantization, a strong choice for inference)
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Llama-2-7B-AWQ",
    fuse_layers=True,
)
```
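
To sanity-check the saving, transformers models expose `get_memory_footprint()`:

```python
# Report the loaded model's memory footprint (works for quantized and full-precision models)
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
```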
### Speculative Decoding
```python
import torch

def speculative_decode(target, draft, prompt, k=4):
    """Small draft model proposes k tokens; the large target verifies them in one pass.

    `tokenize`, `complete`, `verify_and_accept`, and `decode` are placeholder
    helpers that sketch the algorithm, not a library API.
    """
    input_ids = tokenize(prompt)
    while not complete(input_ids):
        # Draft k candidate tokens with the small model
        draft_ids = draft.generate(input_ids, max_new_tokens=k)
        # Score all candidates with the target in a single forward pass
        logits = target(draft_ids).logits
        # Keep the longest prefix of drafted tokens the target agrees with
        accepted = verify_and_accept(draft_ids, logits)
        input_ids = torch.cat([input_ids, accepted], dim=-1)
    return decode(input_ids)
```
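
Hugging Face transformers ships this as assisted generation, so a hand-rolled loop is rarely needed; a sketch, assuming a small draft model that shares the target's tokenizer (checkpoints are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

target = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
draft = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("Explain the KV cache:", return_tensors="pt").to(target.device)
# assistant_model enables speculative (assisted) decoding inside generate()
outputs = target.generate(**inputs, assistant_model=draft, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```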
## Service Optimization

### KV Cache (vLLM)
```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    gpu_memory_utilization=0.9,
    max_model_len=4096,
    enable_prefix_caching=True,  # Reuse KV cache for common prefixes
)
```
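
Generation then goes through `SamplingParams`; a short usage sketch:

```python
from vllm import SamplingParams

sampling = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(["Summarize the benefits of prefix caching."], sampling)
print(outputs[0].outputs[0].text)
```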
### Batching

Continuous batching (vLLM, TGI) adds and removes requests from the running batch at every decode step, so short requests are not stuck waiting behind long ones. When serving a model yourself, a simple dynamic batcher groups requests that arrive close together:
```python
import asyncio

# Dynamic batching: run requests that arrive close together in one forward pass
class DynamicBatcher:
    def __init__(self, model, max_batch=8, max_wait_ms=100):
        self.model = model            # any object with a batched generate(prompts)
        self.queue = []               # pending (request, future) pairs
        self.max_batch = max_batch
        self.max_wait = max_wait_ms   # a timer flush after max_wait is omitted here

    async def add(self, request):
        future = asyncio.get_running_loop().create_future()
        self.queue.append((request, future))
        if len(self.queue) >= self.max_batch:
            await self.process_batch()   # flush as soon as the batch is full
        return await future              # resolved when its batch is processed

    async def process_batch(self):
        batch, self.queue = self.queue, []
        outputs = self.model.generate([req for req, _ in batch])
        for (_, future), output in zip(batch, outputs):
            future.set_result(output)
```
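
Usage from concurrent handlers might look like this (`EchoModel` is a stand-in for any model with a batched `generate`):

```python
class EchoModel:
    def generate(self, prompts):
        return [f"response to: {p}" for p in prompts]

async def main():
    batcher = DynamicBatcher(EchoModel(), max_batch=4)
    # Four concurrent requests fill one batch and are answered together
    results = await asyncio.gather(*(batcher.add(f"prompt {i}") for i in range(4)))
    print(results)

asyncio.run(main())
```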
## Caching

### Exact Cache
```python
import hashlib
import json

import redis

class PromptCache:
    """Exact-match cache: identical prompts return the stored response without a model call."""
    def __init__(self, redis_client=None, ttl=3600):
        self.redis = redis_client or redis.Redis()
        self.ttl = ttl

    def get_or_generate(self, prompt, model):
        # Stable key (built-in hash() changes across processes)
        key = hashlib.sha256(prompt.encode()).hexdigest()
        cached = self.redis.get(key)
        if cached:
            return json.loads(cached)
        response = model.generate(prompt)
        self.redis.setex(key, self.ttl, json.dumps(response))
        return response
```
### Semantic Cache
```python
import numpy as np

class SemanticCache:
    """Reuse a cached response when a new prompt is semantically close to an old one."""
    def __init__(self, embed_fn):
        self.embed = embed_fn     # callable: prompt -> embedding vector (np.ndarray)
        self.embeddings = {}      # prompt -> embedding
        self.responses = {}       # prompt -> response

    def get_or_generate(self, prompt, model, threshold=0.95):
        emb = self.embed(prompt)
        for cached, cached_emb in self.embeddings.items():
            sim = np.dot(emb, cached_emb) / (np.linalg.norm(emb) * np.linalg.norm(cached_emb))
            if sim > threshold:
                return self.responses[cached]
        response = model.generate(prompt)
        self.embeddings[prompt] = emb
        self.responses[prompt] = response
        return response
```
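
A usage sketch with sentence-transformers embeddings (the `embed_fn` constructor argument is this sketch's convention, not a standard API):

```python
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-MiniLM-L6-v2")
cache = SemanticCache(embed_fn=embedder.encode)

# The second, near-duplicate prompt should hit the cache instead of the model
cache.get_or_generate("What is retrieval-augmented generation?", model)
cache.get_or_generate("Explain retrieval augmented generation.", model)
```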
## Best Practices
- Start with quantization (easy win)
- Use vLLM/TGI for serving
- Enable prefix caching
- Add semantic caching for common queries
- Monitor TTFT and throughput