Skip to content

Commit a1b9cd9

Browse files
committed
Add Dockerfile and entrypoint script for the jetstream-pytorch-server image
1 parent a0449f7 commit a1b9cd9

File tree

4 files changed

+89
-1
lines changed

4 files changed

+89
-1
lines changed
+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Ubuntu:22.04
16+
# Use Ubuntu 22.04 from Docker Hub.
17+
# https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04
18+
FROM ubuntu:22.04
19+
20+
ENV DEBIAN_FRONTEND=noninteractive
21+
ENV PYTORCH_JETSTREAM_VERSION=main
22+
23+
RUN apt -y update && apt install -y --no-install-recommends \
24+
ca-certificates \
25+
git \
26+
python3.10 \
27+
python3-pip
28+
29+
RUN python3 -m pip install --upgrade pip
30+
31+
RUN update-alternatives --install \
32+
/usr/bin/python3 python3 /usr/bin/python3.10 1
33+
34+
35+
RUN git clone https://github.com/AI-Hypercomputer/JetStream.git
36+
RUN git clone https://github.com/AI-Hypercomputer/jetstream-pytorch.git && \
37+
cd /jetstream-pytorch && \
38+
git checkout ${PYTORCH_JETSTREAM_VERSION} && \
39+
bash install_everything.sh
40+
41+
RUN pip install -U jax[tpu]==0.4.35 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
42+
43+
RUN cd /JetStream && \
44+
pip install -e .
45+
46+
RUN pip install huggingface_hub[cli]
47+
48+
COPY jetstream_pytorch_server_entrypoint.sh /usr/bin/
49+
50+
RUN chmod +x /usr/bin/jetstream_pytorch_server_entrypoint.sh
51+
52+
ENTRYPOINT ["/usr/bin/jetstream_pytorch_server_entrypoint.sh"]
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
## Build and upload JetStream PyTorch Server image
2+
3+
These instructions are to build the JetStream PyTorch Server image, which calls an entrypoint script that invokes the [JetStream](https://github.com/AI-Hypercomputer/JetStream) inference server with the JetStream-PyTorch framework.
4+
5+
```
6+
docker build -t jetstream-pytorch-server .
7+
docker tag jetstream-pytorch-server us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pytorch-server:latest
8+
docker push us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pytorch-server:latest
9+
```
10+
11+
If you would like to change the version of MaxText the image is built off of, change the `PYTORCH_JETSTREAM_VERSION` environment variable:
12+
```
13+
ENV PYTORCH_JETSTREAM_VERSION=<your desired commit hash, release, or tag>
14+
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
#!/bin/bash
16+
export HUGGINGFACE_TOKEN_DIR="/huggingface"
17+
18+
cd /jetstream-pytorch
19+
huggingface-cli login --token $(cat ${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN)
20+
jpt serve $@

jetstream_pt/cli.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def serve():
109109
if 1 <= FLAGS.prometheus_port <= 65535:
110110
metrics_server_config = MetricsServerConfig(port=FLAGS.prometheus_port)
111111
else:
112-
raise ValueError(f"Invalid port number: {FLAGS.prometheus_port}. Port must be between 1 and 65535.")
112+
raise ValueError(
113+
f"Invalid port number: {FLAGS.prometheus_port}. Port must be between 1 and 65535."
114+
)
113115

114116
# We separate credential from run so that we can unit test it with local credentials.
115117
# We would like to add grpc credentials for OSS.

0 commit comments

Comments
 (0)