Skip to content

Commit 2c32e1d

Browse files
authored
Use prefix in lfm2 output location (#1369)
1 parent 1d08e91 commit 2c32e1d

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/configs.js

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,9 @@ function getNormalizedConfig(config) {
264264
*/
265265
export function getCacheShapes(config, options) {
266266
if (config.model_type === 'lfm2') {
267+
const pkv_prefix = options?.prefix ?? 'past_key_values';
268+
const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past';
269+
267270
// Custom caching mechanism for LFM2
268271
/** @type {Record<string, number[]>} */
269272
const cache_values = {};
@@ -274,10 +277,10 @@ export function getCacheShapes(config, options) {
274277
for (let i = 0; i < layer_types.length; ++i) {
275278
if (layer_types[i] === 'full_attention') {
276279
for (const kv of ['key', 'value']) {
277-
cache_values[`past_key_values.${i}.${kv}`] = [batch_size, num_key_value_heads, 0, head_dim];
280+
cache_values[`${pkv_prefix}.${i}.${kv}`] = [batch_size, num_key_value_heads, 0, head_dim];
278281
}
279282
} else if (layer_types[i] === 'conv') {
280-
cache_values[`past_conv.${i}`] = [batch_size, hidden_size, conv_L_cache];
283+
cache_values[`${conv_prefix}_conv.${i}`] = [batch_size, hidden_size, conv_L_cache];
281284
} else {
282285
throw new Error(`Unsupported layer type: ${layer_types[i]}`);
283286
}

0 commit comments

Comments
 (0)