How to train a Million Context LLM — with Mark Huang of Gradient.ai
Overview
Long context capabilities represent a significant frontier in AI development, with Gradient extending Llama 3's context window from 8,000 to 1 million tokens through techniques like RoPE scaling and specialized attention mechanisms, enabling applications from code repositories to finance.
The technical approach to context extension involves careful consideration of positional encodings, curriculum learning (progressively increasing context length), and data quality, with implementation challenges including GPU memory bandwidth utilization and network topology optimization.
Evaluation of long-context models requires sophisticated benchmarking beyond simple "needle in a haystack" tests, including multiple retrievals, variable tracking, and summary statistics generation, with performance degradation becoming apparent at extremely large contexts (4M+ tokens).
Multimodality is emerging as the next critical frontier in AI development, with early fusion models showing promise for integrating videos, images, and text in ways that provide genuine user value rather than just technical complexity.
The AI research landscape has experienced a 10x information explosion, requiring careful filtering strategies that prioritize practical applications, with Twitter and hands-on product testing proving more valuable for staying current than traditional academic conferences.
Content
Background and Professional Journey
Mark Huang is a former quantitative finance professional who transitioned to tech
Worked as lead data scientist at Box and staff ML scientist at Splunk
Moved from finance to tech to gain more experience with big data and machine learning at scale
Notes a trend of finance professionals moving into tech and AI
Sees current AI landscape as similar to previous "trading wars" in terms of talent competition
Feels empowered by OpenAI's developments to create impactful products
Gradient - Company Overview and Formation
Gradient is a full-stack AI platform
Goal: Enable enterprises to transition from traditional RPA (Robotic Process Automation) to more autonomous, "agentic" workflows
Aims to create a horizontal platform for AI workforce transformation
Formed a team with Chris Chang (former Meta/Google/Netflix engineer)
Motivated by challenges in enterprise ML platforms, particularly frequent workflow migrations
Goal was to reduce operational friction in shipping workloads
Agent Definition and Perspective
Mark defines an agent beyond just non-deterministic execution
Focuses on marginal improvements in probability of success at each workflow stage
Acknowledges "agent" is an overloaded term in current AI landscape
Emphasizes statistical approach to measuring agent effectiveness
Core Technical Vision
Focus on developing systems that can handle "out of domain" problems
Emphasize machine learning as a continuous learning process
Desire for AI systems that grow and adapt alongside users
Viewed the project as part of broader "meta learning" workflow
Interested in adaptable AI systems that can generalize across different domains
Long Context Learning Project
Chose to extend Llama 3's context length
Motivated by existing models' short context windows (8,000 tokens)
Inspired by Google's Gemini with 1 million token context length
Viewed language models as "compression algorithms"
Worked with Crusoe (computational infrastructure provider) to facilitate the project
Recognized not everyone can easily undertake such computational challenges
Discussed GPU cloud providers and their collaboration to scale up computational resources using L40S GPU instances
Combined flash attention and ring attention for training
Ring attention is primarily about better GPU memory bandwidth utilization
Evaluated multiple implementation approaches for ring attention
Original JAX implementation was not GPU-friendly
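The blockwise computation underlying flash and ring attention can be illustrated in a single process. The sketch below (NumPy, illustrative only, not Gradient's implementation) computes attention one key/value block at a time using the online-softmax trick; ring attention additionally shards these KV blocks across GPUs and passes them around a ring, but the per-block math is the same.

```python
import numpy as np

def blockwise_attention(q, k, v, block_size):
    """Numerically stable attention computed one KV block at a time
    via online softmax (the core trick behind flash attention; ring
    attention distributes the KV blocks across devices)."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(q)
    running_max = np.full(n, -np.inf)
    running_sum = np.zeros(n)
    for start in range(0, k.shape[0], block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        scores = q @ kb.T * scale                      # (n, block)
        new_max = np.maximum(running_max, scores.max(axis=1))
        # rescale previously accumulated output and normalizer
        correction = np.exp(running_max - new_max)
        p = np.exp(scores - new_max[:, None])
        out = out * correction[:, None] + p @ vb
        running_sum = running_sum * correction + p.sum(axis=1)
        running_max = new_max
    return out / running_sum[:, None]

def full_attention(q, k, v):
    """Reference implementation that materializes the full score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[1])
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ v
```

The blockwise version never materializes the full `n × n` score matrix, which is why memory bandwidth, rather than FLOPs, becomes the binding constraint at long sequence lengths.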
Technical Approach to Context Length Extension
Self-attention has quadratic memory scaling, making longer context sequences computationally expensive
Ongoing research about the best approach to training long context models
Curriculum learning (progressively increasing context length) may perform better than training on maximum context length from the start
Meta research suggests incrementally increasing context length can improve model performance
Data quality is crucial when extending context length
Models need low perplexity scores before context length extension
The "theta" parameter plays a significant role in determining how far a context can be extended
Positional encodings and RoPE scaling are important technical mechanisms for context extension
Practical takeaway: starting from a 4k-context model, you can progressively extend the context length provided the model shows good initial performance
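The curriculum idea above can be sketched as a staged training schedule. The stage lengths, step counts, and theta values below are made-up illustrations of the shape of such a schedule, not Gradient's actual recipe; the packing helper shows one way to avoid truncating documents mid-way when assembling long sequences.

```python
# Illustrative curriculum for context extension: train in stages that
# progressively increase context length (and RoPE theta) rather than
# jumping straight to the maximum. All numbers are hypothetical.
CURRICULUM = [
    # (context_length, rope_theta, training_steps)
    (8_192,     500_000,       100),
    (65_536,    10_000_000,     75),
    (262_144,   100_000_000,    50),
    (1_048_576, 1_000_000_000,  30),
]

def pack_documents(token_lengths, context_length):
    """Greedily pack documents into sequences of at most `context_length`
    tokens, so no document is truncated partway through."""
    packs, current, used = [], [], 0
    for length in token_lengths:
        if used + length > context_length and current:
            packs.append(current)
            current, used = [], 0
        current.append(length)
        used += length
    if current:
        packs.append(current)
    return packs
```

At each stage, batches would be packed to the stage's context length before stepping the optimizer; the point is only that the sequence length the model sees grows gradually.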
Technical Details on Model Embedding and Scaling
Focus on embedding mechanisms, particularly positional encoding techniques
Theta scaling described as an empirical method for adjusting embedding distributions
Goal is to achieve interpolation rather than extrapolation in model context
Approach was developed incrementally, starting at 256 tokens and scaling up
Most current architectures are using RoPE (Rotary Positional Embedding) scaling
ALiBi (Attention with Linear Biases) is less commonly used in recent models
YaRN can be used alongside RoPE scaling
PoSE (Positional Skip-wise training) shows some limitations in very long context scenarios
The scaling approach is empirical rather than mathematically proven
Scaling laws are observed but not guaranteed to continue consistently
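The interpolation-vs-extrapolation point can be made concrete with RoPE's rotation angles. In the sketch below (theta values are illustrative), raising the theta base slows the lowest-frequency rotations enough that a 1M-token position rotates less than an 8k-token position did under the old base, i.e. extended positions land inside the angle range the model saw during training.

```python
import numpy as np

def rope_angles(position, dim, theta):
    """RoPE rotation angles for a single position:
    angle_i = position / theta**(i/dim), for i = 0, 2, ..., dim-2."""
    i = np.arange(0, dim, 2, dtype=np.float64)
    return position / theta ** (i / dim)

dim = 128
# Slowest-rotating component at the edge of an 8k trained range:
trained_edge = rope_angles(8_192, dim, theta=500_000.0)[-1]
# With a much larger theta base, a 1M-token position rotates *less*
# than the 8k edge did under the old base -- the model interpolates
# within familiar angles instead of extrapolating past them.
extended = rope_angles(1_048_576, dim, theta=1e9)[-1]
assert extended < trained_edge
```

This is why the choice of theta, rather than any architectural change, does most of the work in context extension; the catch (as noted above) is that the right value is found empirically, not derived.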
Implementation Details
Discussed an open-source PyTorch implementation by John Payne for context extension
Preferred PyTorch over Jax for implementation
Adapted the implementation for their specific cluster network topology
Dataset and Training Approach
Conducted two-stage model updates:
- Initial pre-training layer using the SlimPajama dataset
- Chat dataset layer using UltraChat or derivatives
Focused on dataset considerations:
- Avoiding token truncation
- Ensuring content diversity
- Using embeddings for pre-filtering
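One way embeddings can serve as a pre-filter is greedy near-duplicate removal. The sketch below is a generic illustration (the threshold and the O(n²) scan are simplifications; a production pipeline would use an approximate-nearest-neighbor index and a real embedding model).

```python
import numpy as np

def filter_near_duplicates(embeddings, threshold=0.95):
    """Greedy pre-filter: keep a document only if its embedding's cosine
    similarity to every already-kept document stays below `threshold`.
    O(n^2) -- fine for a sketch, too slow for web-scale corpora."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    kept = []
    for i, vec in enumerate(normed):
        if all(vec @ normed[j] < threshold for j in kept):
            kept.append(i)
    return kept
```

The same similarity scores can also drive the diversity goal in the other direction: sampling documents that are far apart in embedding space rather than merely dropping near-duplicates.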
Challenges in model training:
- Difficulty in injecting truly new knowledge into large language models
- Models now trained on double-digit trillion tokens
- Challenge of maintaining existing capabilities while introducing new information
- Limited empirical research on expanding model decision boundaries
Cautious about assuming small token additions can significantly alter model knowledge
Referenced a Llama 2 example where further training potentially degraded language capabilities
Emphasized the importance of maintaining model flexibility and generalizability
Advanced Training Techniques
Discussing challenges of model training, particularly avoiding overfitting to specific data types
Proposing multi-stage training with mixed data sources to prevent deviation
Suggesting potential improvements to loss functions to manage data overfitting
Using GPT-4 to rephrase and generate new training data tokens
Injecting out-of-domain, lower-probability data instances
Recognizing data pipeline creation as a potentially significant part of model development
Synthetic data generation can represent 25-50% of a dataset
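A minimal sketch of mixing rephrased synthetic data into a corpus at a target fraction follows. `rephrase_with_gpt4` is a hypothetical placeholder for the GPT-4 call mentioned above, not a real API binding.

```python
import random

def rephrase_with_gpt4(text):
    """Placeholder for a GPT-4 rephrasing call -- hypothetical stand-in;
    a real pipeline would call the OpenAI API and validate the output."""
    return f"[rephrased] {text}"

def build_mixed_dataset(real_examples, synthetic_fraction=0.3, seed=0):
    """Mix real examples with synthetic rephrasings so synthetic text
    makes up `synthetic_fraction` of the final dataset (cf. the 25-50%
    range noted above)."""
    assert 0.0 <= synthetic_fraction < 1.0
    n_synthetic = round(len(real_examples) * synthetic_fraction
                        / (1 - synthetic_fraction))
    rng = random.Random(seed)
    synthetic = [rephrase_with_gpt4(rng.choice(real_examples))
                 for _ in range(n_synthetic)]
    mixed = real_examples + synthetic
    rng.shuffle(mixed)
    return mixed
```

Shuffling real and synthetic examples together, rather than training on them in separate phases, is one simple guard against the overfitting-to-one-data-type problem discussed above.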
Model Adaptation Techniques
Discussing LoRA (Low-Rank Adaptation) techniques for language models
Exploring model "alchemy" - mixing LoRA adapters and model merging
Comparing LoRA adaptations across different domains (language models vs. Stable Diffusion)
Techniques for merging machine learning models, particularly using LoRA layers
Observations that model merging can be effective for stylistic tasks but may struggle with more complex abilities
Merging techniques are seen as potentially "polluting" leaderboards by allowing strategic model combinations
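The adapter-mixing "alchemy" above reduces to simple linear algebra on low-rank updates. The sketch below (NumPy, with made-up dimensions and randomly initialized adapters standing in for trained ones) folds a weighted sum of LoRA updates back into a dense weight: W' = W + alpha * sum_k w_k * (B_k @ A_k).

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank = 16, 16, 4

# Base weight plus two LoRA adapters, hypothetically trained on
# different domains (here just random matrices for illustration).
W = rng.normal(size=(d_out, d_in))
adapters = []
for _ in range(2):
    A = rng.normal(size=(rank, d_in))    # down-projection
    B = rng.normal(size=(d_out, rank))   # up-projection
    adapters.append((A, B))

def merge_loras(W, adapters, weights, alpha=1.0):
    """Fold a weighted sum of low-rank updates into the dense weight:
    W' = W + alpha * sum_k weights[k] * (B_k @ A_k)."""
    merged = W.copy()
    for (A, B), w in zip(adapters, weights):
        merged += alpha * w * (B @ A)
    return merged

W_merged = merge_loras(W, adapters, weights=[0.5, 0.5])
```

Because the merge is a plain weighted sum, it tends to transfer style-like behavior well while offering no guarantee that more compositional abilities combine, which is consistent with the limitations noted above.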
Evaluation Challenges
Evaluating model performance is complex due to the high-dimensional, sparse nature of assessments
Multiple evaluations provide insights but not a complete picture
Highlighting difficulties in evaluating complex, advanced AI tasks beyond simple "needle in a haystack" retrieval
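For reference, the baseline test that the critiques above build on can be written in a few lines. This is a generic sketch (the needle text, filler, and mock model are all placeholders): a needle sentence is inserted at several relative depths of a long filler context, and the model is asked to retrieve it. The harder variants mentioned earlier, multiple retrievals and variable tracking, extend this same harness.

```python
def needle_in_haystack_eval(model, haystack_chars, depths,
                            needle="The magic number is 7421."):
    """Minimal needle-in-a-haystack harness. `model` is any
    callable(prompt) -> str; returns {depth: retrieved?}."""
    filler = "The grass is green. " * (haystack_chars // 20)
    results = {}
    for depth in depths:
        cut = int(len(filler) * depth)
        prompt = (filler[:cut] + needle + filler[cut:]
                  + "\nWhat is the magic number?")
        results[depth] = "7421" in model(prompt)
    return results
```

A real run would sweep both context length and depth and plot the retrieval grid; the point of the section is that passing this grid says little about multi-needle retrieval or reasoning over the retrieved values.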