Ray engine crashes on multihost when fetching jax.Array from prefill_ray #150

@richardsliu

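`prefill_ray()` now returns a [result, first_token] tuple, where first_token contains a JAX array. On multihost, this crashes when the Ray results are fetched remotely: Ray pickles the return value, and pickling a `jax.Array` that spans devices not addressable from the worker process raises a RuntimeError: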

job_id:06000000
:actor_name:ServeReplica:default:JetStreamDeployment
SIGTERM handler is not set because current thread is not the main thread.
Using address example-cluster-kuberay-head-svc.default.svc.cluster.local:6379 set in the environment variable RAY_ADDRESS
Connecting to existing Ray cluster at address: example-cluster-kuberay-head-svc.default.svc.cluster.local:6379...
Calling ray.init() again after it has already been called.
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 162, in run
    super().run()
  File "/home/ray/anaconda3/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 507, in _prefill_thread
    prefill_result, first_token = prefill_engine.prefill(
  File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 83, in prefill
    return self.prefill_impl(
  File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 113, in prefill_impl
    results = ray.get(all_outputs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2623, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 861, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::PyTorchRayWorker.prefill_ray() (pid=14601, ip=10.104.7.5, actor_id=0721a490262f0d248878f59d06000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7974fc14e410>)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 449, in __reduce__
    fun, args, arr_state = self._value.__reduce__()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
    raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
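
One possible workaround, sketched here under assumptions rather than taken from the jetstream_pt code, is to convert first_token to a host-side NumPy array inside the worker before `prefill_ray()` returns, so Ray never has to pickle a multi-host `jax.Array`. The helper name `to_host_array` is hypothetical. It duck-types on the `jax.Array` attributes `is_fully_addressable`, `is_fully_replicated`, and `addressable_data`, and it assumes first_token is replicated across hosts, which this issue does not confirm.

```python
import numpy as np


def to_host_array(x):
    """Return a NumPy copy of x that is safe to send through Ray.

    Hypothetical helper: relies only on jax.Array attributes, so it can be
    exercised without a multi-host setup.
    """
    if not hasattr(x, "addressable_data"):
        # Not a jax.Array (plain list, scalar, ndarray): convert directly.
        return np.asarray(x)
    if x.is_fully_addressable:
        # All shards live on this process, so a normal fetch works.
        return np.asarray(x)
    if x.is_fully_replicated:
        # Assumption: every process holds a full copy, so the first local
        # shard already contains the complete value.
        return np.asarray(x.addressable_data(0))
    # Genuinely sharded across hosts: the caller has to gather it first,
    # e.g. with jax.experimental.multihost_utils.process_allgather.
    raise ValueError("array is sharded across hosts; gather it before returning")
```

In the worker this would be applied to first_token just before the return statement. If first_token turns out to be sharded rather than replicated, the `process_allgather` route named in the error message is likely the better fit.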
