# TPU generations whose accelerator-type suffix counts chips (one TensorCore
# per chip). Every other generation is named by TensorCore count, which is
# twice the chip count.
_CHIP_COUNTED_GENERATIONS = frozenset({"v5litepod", "v6e"})


def create_head_resource_name(generation, tpu_chips):
    """Return the Ray resource name advertised by a TPU slice's head node.

    The name has the form ``TPU-<generation>-<count>-head``. ``<count>`` must
    match the accelerator-type suffix: the chip count for generations listed
    in ``_CHIP_COUNTED_GENERATIONS``, and twice the chip count otherwise.

    Args:
        generation: TPU generation string, e.g. ``"v4"`` or ``"v5litepod"``.
        tpu_chips: Number of TPU chips in the slice (an integer).

    Returns:
        The head resource name as a string.
    """
    if generation in _CHIP_COUNTED_GENERATIONS:
        unit_count = tpu_chips
    else:
        # Two TensorCores per chip, and the accelerator type counts cores.
        unit_count = tpu_chips * 2
    return f"TPU-{generation}-{unit_count}-head"