Skip to content

Commit 36f8a23

Browse files
authored
Handle v5e-8 in run_ray_serve_interleave (#162)
* update ray; add v5 configs
* Handle worker_chips in run_ray_serve_interleave.py
* format
1 parent ed1e853 commit 36f8a23

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

run_ray_serve_interleave.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -34,13 +34,13 @@
3434
flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
3535
flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")
3636
flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False)
37+
flags.DEFINE_integer(
38+
"worker_chips", 4, "Number of TPU chips per worker", required=False
39+
)
3740

3841

3942
def create_head_resource_name(generation, tpu_chips):
40-
# TODO: Make this work for special cases like v5e-8
41-
num_cores = 2 * tpu_chips
42-
43-
return f"TPU-{generation}-{num_cores}-head"
43+
return f"TPU-{generation}-{tpu_chips}-head"
4444

4545

4646
def create_engine(**kwargs):
@@ -60,6 +60,9 @@ def create_engine(**kwargs):
6060
quantize_kv=kwargs["quantize_kv"],
6161
max_cache_length=kwargs["max_cache_length"],
6262
sharding_config=kwargs["sharding_config"],
63+
num_hosts=kwargs["num_hosts"],
64+
worker_chips=kwargs["worker_chips"],
65+
tpu_chips=kwargs["tpu_chips"],
6366
enable_jax_profiler=kwargs["enable_jax_profiler"],
6467
jax_profiler_port=kwargs["jax_profiler_port"],
6568
)
@@ -124,6 +127,8 @@ def main(_argv):
124127
ray_actor_options={"resources": {resource_name: 1}}
125128
).bind(
126129
tpu_chips=FLAGS.tpu_chips,
130+
worker_chips=FLAGS.worker_chips,
131+
num_hosts=FLAGS.num_hosts,
127132
model_name=FLAGS.model_name,
128133
tokenizer_path=FLAGS.tokenizer_path,
129134
ckpt_path=FLAGS.checkpoint_path,

0 commit comments

Comments
 (0)