Stacked cache for MLPerf #154

Merged: 50 commits, Jul 20, 2024

Commits
1e14081
Almost working except mask, need to rebase to main to pick up the the…
wang2yn84 Jul 2, 2024
64a7d4d
Fixed the test_model_impl for llama, but test_llama_e2e is still fail…
wang2yn84 Jul 2, 2024
66e2c3b
Adds lazy_cache_update and restructure the cache flags.
wang2yn84 Jul 3, 2024
0eb11dc
Disable all the prints. Fix create engine.
wang2yn84 Jul 3, 2024
cbf9e68
Fix typos and minor errors.
wang2yn84 Jul 3, 2024
0fdd8e1
Fixes create engine.
wang2yn84 Jul 3, 2024
ed9ea97
Adds new_cache_stacked and fixes cache update.
wang2yn84 Jul 4, 2024
69aba1e
Fix cache update when new_cache_stacked is False.
wang2yn84 Jul 4, 2024
fbf2fa6
Fix the cache manager and make unit tests pass except for 1.
wang2yn84 Jul 7, 2024
7f3eb0b
Updates the exportable model to return cache.
wang2yn84 Jul 7, 2024
fbd2dbd
Removed the fori loop in cache finalize. Moves the cache.finalize() t…
wang2yn84 Jul 8, 2024
a1e6742
Try to use shard_map for cache update.
wang2yn84 Jul 8, 2024
2e3951a
Fix update single cache line in cache.finalize()
wang2yn84 Jul 8, 2024
ecf5662
Adds int8 support.
wang2yn84 Jul 8, 2024
41a07d2
Int8 left aligned lazy cache update working, performance still not go…
wang2yn84 Jul 9, 2024
bdc6d4a
Fix the stacked cache introduced in the previous couple of commits.
wang2yn84 Jul 9, 2024
5443d9b
Put original ragged attention back.
wang2yn84 Jul 10, 2024
0f0af46
Add the original ragged attention kernel.
wang2yn84 Jul 10, 2024
9d80885
Fixes the bf16/int8 cache stack.
wang2yn84 Jul 10, 2024
ce4caeb
Fix int8 stacked cache insertion in engine and finalization.
wang2yn84 Jul 10, 2024
e099066
Fixes int8 with lazy cache update.
wang2yn84 Jul 11, 2024
e71c6d6
Updates the int8 test.
wang2yn84 Jul 11, 2024
662af29
Fix the int8 ragged attention output sharding.
wang2yn84 Jul 11, 2024
6c769a4
Fix group query attention broadcasting issue.
wang2yn84 Jul 11, 2024
6ff9311
Fix shard map input issue. Variables not listed as inputs are frozen…
wang2yn84 Jul 11, 2024
20821a4
Fix the flash attention mask shape; Fix the update single cache line …
wang2yn84 Jul 12, 2024
cb293c2
Adds the kv cache test.
wang2yn84 Jul 12, 2024
29d9670
Replace quantized cache "pos" with "input_pos" to align with bf16 cac…
wang2yn84 Jul 12, 2024
fe0dc8c
Fix prefill cache insertion issue for stacked cache; Changes reduce d…
wang2yn84 Jul 13, 2024
063662d
Adds lazy cache update with generate cache stacked new cache unstacke…
wang2yn84 Jul 15, 2024
143d5a6
Fix the shard map sharding for stacked generate cache and unstacked n…
wang2yn84 Jul 15, 2024
0a386e2
Use the JAX API for slicing instead of PyTorch index slicing.
wang2yn84 Jul 15, 2024
b53476d
Adds stacked cache support in ragged attention reference kernel.
wang2yn84 Jul 16, 2024
5be7d0b
Adds stacked cache support for the modified ragged kernel.
wang2yn84 Jul 16, 2024
ac0a88b
Llama2 70b int8 optimization done. Output not correct yet.
wang2yn84 Jul 16, 2024
08358b2
Remove testing temp output files.
wang2yn84 Jul 16, 2024
0c0162e
Fix the llama 70b output accuracy resulting from gqa.
wang2yn84 Jul 16, 2024
d607920
Fixes the attention output slicing issue when not using flash attenti…
wang2yn84 Jul 17, 2024
a59338c
Fix the pallas kernel OOB issue
wang2yn84 Jul 18, 2024
181a809
Fix tests; Fix lint issues;
wang2yn84 Jul 18, 2024
766e14c
Fix the interactive script.
wang2yn84 Jul 18, 2024
03d9ba6
Fix lint errors.
wang2yn84 Jul 19, 2024
331aabf
Fix errors.
wang2yn84 Jul 19, 2024
0f10636
Fix the comments.
wang2yn84 Jul 19, 2024
ba61830
Fix based on comments; Fix all the unit tests.
wang2yn84 Jul 19, 2024
d36ac81
Fix the remaining pylint errors.
wang2yn84 Jul 19, 2024
93451db
Default ring buffer back to true so that all the test_run_server and …
wang2yn84 Jul 19, 2024
e8f1469
Fix all the lint errors.
wang2yn84 Jul 19, 2024
134247c
Remove the deps/JetStream changes.
wang2yn84 Jul 20, 2024
d36641f
Fix merge errors, fix lint errors.
wang2yn84 Jul 20, 2024
benchmarks/run_offline.py: 14 additions & 6 deletions

@@ -32,7 +32,7 @@
flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file")


def run_prefill_time(engine, params, decode_state, seqlen):
def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
"""Run prefill and measure time."""
metadata = engine.get_tokenizer()
tokenizer = engine.build_tokenizer(metadata)
@@ -53,15 +53,20 @@ def run_prefill_time(engine, params, decode_state, seqlen):
nums = 5
start = time.perf_counter()
for i in range(nums):
if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
jax.profiler.start_trace(FLAGS.profiling_output)
profiler_started = True

prefill_result, _ = engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
decode_state = engine.insert(
prefill_result, decode_state, slot=jnp.int32(i)
)
jax.block_until_ready(decode_state)

end = time.perf_counter()
return (end - start) / nums, decode_state
return (end - start) / nums, decode_state, profiler_started


MAXTEXT_PREFILL = {
@@ -86,9 +91,10 @@ def main(argv):
prefill_times = {}

decode_state = engine.init_decode_state()
profiler_started = False
for batch, _ in MAXTEXT_PREFILL.items():
runtime, decode_state = run_prefill_time(
engine, params, decode_state, batch
runtime, decode_state, profiler_started = run_prefill_time(
engine, params, decode_state, batch, profiler_started
)
prefill_times[batch] = runtime

@@ -103,10 +109,12 @@

profiling_output = FLAGS.profiling_output
print("======= decode starting ===")

dec_times = []
for i in range(10):
if profiling_output and i == 7:
if profiling_output and i == 7 and not profiler_started:
jax.profiler.start_trace(profiling_output)
profiler_started = True
start = time.perf_counter()
# pylint: disable-next=all
decode_state, sampled_tokens = engine.generate(params, decode_state)
@@ -116,7 +124,7 @@
dec_times.append(end - start)
print(i, "decode time", (end - start))

if profiling_output:
if profiler_started:
jax.profiler.stop_trace()

print("prefill ", prefill_times)
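For context on the benchmark change above: run_offline.py now threads a profiler_started flag through run_prefill_time and main so that jax.profiler.start_trace is called at most once (on the last timed prefill iteration when prefill profiling is enabled, otherwise at decode step 7) and jax.profiler.stop_trace runs only if a trace was actually started. Below is a minimal sketch of that guard pattern; profile_once and the loop bodies are placeholders for illustration, not code from this PR, and only jax.profiler.start_trace/stop_trace are real JAX APIs.

# Minimal sketch of the single-start profiler guard (hypothetical helper).
import jax


def profile_once(profiling_output: str, profiling_prefill: bool):
  """Starts the JAX trace at most once and stops it only if it was started."""
  profiler_started = False

  # Prefill timing loop: optionally start the trace on the last iteration.
  nums = 5
  for i in range(nums):
    if i == nums - 1 and profiling_prefill and not profiler_started:
      jax.profiler.start_trace(profiling_output)
      profiler_started = True
    # ... run one prefill + insert step here (placeholder) ...

  # Decode loop: start the trace at step 7 unless it is already running.
  for i in range(10):
    if profiling_output and i == 7 and not profiler_started:
      jax.profiler.start_trace(profiling_output)
      profiler_started = True
    # ... run one generate step here (placeholder) ...

  # Stop only a trace we started, so stop_trace() never runs without a
  # matching start_trace().
  if profiler_started:
    jax.profiler.stop_trace()

In the actual script these values come from FLAGS.profiling_output and FLAGS.profiling_prefill, and run_prefill_time returns profiler_started so main can carry the same flag into the decode loop.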