34
34
# Profiler and TPU topology flags.
flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")
flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False)
flags.DEFINE_integer(
    "worker_chips", 4, "Number of TPU chips per worker", required=False
)
37
40
38
41
39
42
def create_head_resource_name(generation, tpu_chips):
    """Return the head resource label ``TPU-<generation>-<tpu_chips>-head``.

    Args:
        generation: TPU generation identifier, inserted verbatim (e.g. "v4").
        tpu_chips: Chip count, inserted verbatim; no chips-to-cores
            conversion is applied.

    Returns:
        The formatted resource name string.
    """
    return f"TPU-{generation}-{tpu_chips}-head"
44
44
45
45
46
46
def create_engine (** kwargs ):
@@ -60,6 +60,9 @@ def create_engine(**kwargs):
60
60
quantize_kv = kwargs ["quantize_kv" ],
61
61
max_cache_length = kwargs ["max_cache_length" ],
62
62
sharding_config = kwargs ["sharding_config" ],
63
+ num_hosts = kwargs ["num_hosts" ],
64
+ worker_chips = kwargs ["worker_chips" ],
65
+ tpu_chips = kwargs ["tpu_chips" ],
63
66
enable_jax_profiler = kwargs ["enable_jax_profiler" ],
64
67
jax_profiler_port = kwargs ["jax_profiler_port" ],
65
68
)
@@ -124,6 +127,8 @@ def main(_argv):
124
127
ray_actor_options = {"resources" : {resource_name : 1 }}
125
128
).bind (
126
129
tpu_chips = FLAGS .tpu_chips ,
130
+ worker_chips = FLAGS .worker_chips ,
131
+ num_hosts = FLAGS .num_hosts ,
127
132
model_name = FLAGS .model_name ,
128
133
tokenizer_path = FLAGS .tokenizer_path ,
129
134
ckpt_path = FLAGS .checkpoint_path ,
0 commit comments