From a9b4caf543ac80a622693f86561da80648be5fad Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Thu, 1 Aug 2024 04:16:43 +0000 Subject: [PATCH] Fix Ray engine crash on multihost --- jetstream_pt/ray_worker.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 01b647d..8738bcd 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -466,6 +466,9 @@ def prefill_ray( logits = logits[0] token = np.argmax(logits[true_length - 1]) + updated_caches = multihost_utils.process_allgather( + updated_caches, tiled=True + ) prefix = Prefix(token, updated_caches, true_length) self.prefix_queue.put(prefix, block=False)