# JetStream Engine implementation in PyTorch
- Ssh to Cloud TPU VM (using v5e-8 TPU VM)
  - Create a Cloud TPU VM if you haven't
- Download jetstream-pytorch github repo
- Clone repo and install dependencies
- Download and convert weights
- Run checkpoint converter (quantizer)
- Local run
- Run the server
- Run benchmarks
- Typical Errors
## Ssh to Cloud TPU VM (using v5e-8 TPU VM)

```bash
gcloud compute config-ssh
gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone"
```

If you haven't created a Cloud TPU VM yet, follow the steps in the Cloud TPU documentation: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm
## Download jetstream-pytorch github repo

```bash
git clone https://github.com/google/jetstream-pytorch.git
```

(Optional) Create a virtual env using venv or conda and activate it.

## Clone repo and install dependencies

```bash
cd jetstream-pytorch
source install_everything.sh
```
NOTE: the script above exports PYTHONPATH, so it must be sourced (not executed) for the change to take effect in the current shell.
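To confirm the install can see the accelerators, one quick check (assuming jax was installed by the script above) is to list the visible devices from Python:

```python
# Sanity check: list the TPU devices jax can see.
import jax

devices = jax.devices()
print(len(devices), "devices:", devices)  # expect 8 entries on a v5e-8
```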
## Download and convert weights

Follow the instructions at https://github.com/meta-llama/llama#download to download the weights. The download also includes a tokenizer.model file, which is the tokenizer we will use.
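As a quick sanity check of the downloaded tokenizer (tokenizer.model is a SentencePiece model; this assumes the sentencepiece package is installed):

```python
# Round-trip a string through the downloaded tokenizer.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="/path/to/tokenizer.model")  # adjust the path
ids = sp.encode("hello world")
print(ids)             # token ids
print(sp.decode(ids))  # "hello world"
```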
## Run checkpoint converter (quantizer)

```bash
export input_ckpt_dir=<directory containing the original llama weights>
export output_ckpt_dir=<output directory>
export quantize=True  # whether to quantize
python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
```
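For intuition about what weight quantization does, the sketch below shows plain symmetric per-channel int8 quantization in PyTorch. It illustrates the general technique only; it is not the converter's actual implementation.

```python
# Illustrative symmetric per-channel int8 quantization (not the converter's actual code).
import torch

def quantize_int8(w: torch.Tensor):
    # One scale per output channel (row), mapping each row onto [-127, 127].
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

w = torch.randn(4096, 4096)
q, scale = quantize_int8(w)
print((w - dequantize_int8(q, scale)).abs().max())  # small reconstruction error
```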
## Local run

Set the tokenizer path:

```bash
export tokenizer_path=<path to the tokenizer.model file from meta-llama>
```

Llama 7b:

```bash
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
```

Llama 13b:

```bash
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
```
## Run the server

NOTE: in `--platform=tpu=8`, the number after `tpu=` is the number of TPU devices (4 for a v4-8, 8 for a v5e-8).

```bash
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
```

Now you can send gRPC requests to it; a minimal client sketch follows.
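The sketch below only shows the generic shape of such a client; the stub class, request message, and port are assumptions here, and the actual proto definitions and a reference client live under deps/JetStream.

```python
# Hypothetical client sketch: the commented-out stub/message names are assumptions;
# see deps/JetStream for the actual proto definitions and a reference client.
import grpc

channel = grpc.insecure_channel("localhost:9000")  # server address/port is an assumption
# stub = jetstream_pb2_grpc.OrchestratorStub(channel)               # generated gRPC stub
# for response in stub.Decode(jetstream_pb2.DecodeRequest(...)):    # server streams tokens back
#     print(response)
```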
## Run benchmarks

Go to the deps/JetStream folder (downloaded during `install_everything.sh`):

```bash
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
```
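To see what the benchmark feeds the server, you can peek at the dataset (the ShareGPT file is a JSON list of conversations; the field names below follow the dataset's published format):

```python
# Inspect the ShareGPT dataset used by the benchmark.
import json

with open("ShareGPT_V3_unfiltered_cleaned_split.json") as f:
    data = json.load(f)

print(len(data), "conversations")
turn = data[0]["conversations"][0]
print(turn["from"], ":", turn["value"][:80])  # first turn of the first conversation
```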
Please look at deps/JetStream/benchmarks/README.md for more information.
## Typical Errors

If the install left broken jax/jaxlib dependencies, fix:

- Uninstall the jax and jaxlib dependencies
- Reinstall using `source install_everything.sh`

If the server or model runs out of memory, fix (see the estimate below for why these help):

- Use a smaller batch size
- Use quantization
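At these settings the KV cache can dominate device memory, and its size scales linearly with batch size and with bytes per element. A rough estimate, assuming Llama-7B-like shapes (32 layers, 32 heads, head dim 128; the shapes are illustrative assumptions):

```python
# Back-of-the-envelope KV cache size for a Llama-7B-shaped model (assumed shapes).
layers, kv_heads, head_dim = 32, 32, 128
batch_size, cache_len = 128, 2048

def kv_cache_gib(bytes_per_elem: int) -> float:
    # 2x for keys and values.
    elems = 2 * layers * kv_heads * head_dim * batch_size * cache_len
    return elems * bytes_per_elem / 2**30

print(f"bf16 KV cache: {kv_cache_gib(2):.0f} GiB")  # 128 GiB
print(f"int8 KV cache: {kv_cache_gib(1):.0f} GiB")  # 64 GiB: quantization halves it
# Halving batch_size halves both numbers again.
```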