Description
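prefill_ray() now returns a [result, first_token] tuple, where first_token contains a JAX array. This causes a crash when the Ray results are fetched remotely with ray.get(): Ray has to pickle the return value, and the array spans devices that the worker process cannot address: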
job_id:06000000
:actor_name:ServeReplica:default:JetStreamDeployment
SIGTERM handler is not set because current thread is not the main thread.
Using address example-cluster-kuberay-head-svc.default.svc.cluster.local:6379 set in the environment variable RAY_ADDRESS
Connecting to existing Ray cluster at address: example-cluster-kuberay-head-svc.default.svc.cluster.local:6379...
Calling ray.init() again after it has already been called.
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Traceback (most recent call last):
File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 162, in run
super().run()
File "/home/ray/anaconda3/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 507, in _prefill_thread
prefill_result, first_token = prefill_engine.prefill(
File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 83, in prefill
return self.prefill_impl(
File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 113, in prefill_impl
results = ray.get(all_outputs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
return func(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2623, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 861, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::PyTorchRayWorker.prefill_ray() (pid=14601, ip=10.104.7.5, actor_id=0721a490262f0d248878f59d06000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7974fc14e410>)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
cp.dump(obj)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
return super().dump(obj)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 449, in __reduce__
fun, args, arr_state = self._value.__reduce__()
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
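A possible workaround, sketched here under assumptions rather than taken from the jetstream_pt code: convert first_token to host-side NumPy data inside the worker before returning it, so that cloudpickle never has to serialize a jax.Array that spans non-addressable devices. The names to_picklable and prefill_ray_stub are hypothetical and exist only for illustration. Returning only the process-local shards is one option; the error message also points at jax.experimental.multihost_utils.process_allgather when the full global array is needed.

```python
import pickle

import jax
import jax.numpy as jnp
import numpy as np


def to_picklable(value):
    """Convert a jax.Array into host NumPy data that pickles cleanly.

    Hypothetical helper for illustration; not part of jetstream_pt.
    """
    if not isinstance(value, jax.Array):
        return value
    if value.is_fully_addressable:
        # Every shard lives on a device this process owns, so the
        # whole array can be copied to host memory.
        return np.asarray(value)
    # Otherwise only the process-local shards can be read here.
    return [np.asarray(shard.data) for shard in value.addressable_shards]


def prefill_ray_stub():
    """Stand-in for a worker method that returns (result, first_token)."""
    result = {"num_tokens": 3}  # placeholder for the real prefill result
    first_token = jnp.array([[42]], dtype=jnp.int32)
    return result, to_picklable(first_token)


if __name__ == "__main__":
    result, first_token = prefill_ray_stub()
    # The tuple now survives a pickle round trip, which is what Ray
    # performs when it ships a task's return value to the caller.
    restored = pickle.loads(pickle.dumps((result, first_token)))
    print(type(restored[1]).__name__)
```

With this shape the caller receives a NumPy array (or a list of local shards) instead of a jax.Array, so any code downstream of ray.get() that expects a device array would need to convert it back.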