1 file changed: +5 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -264,6 +264,9 @@ function getNormalizedConfig(config) {
264
264
*/
265
265
export function getCacheShapes ( config , options ) {
266
266
if ( config . model_type === 'lfm2' ) {
267
+ const pkv_prefix = options ?. prefix ?? 'past_key_values' ;
268
+ const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past' ;
269
+
267
270
// Custom caching mechanism for LFM2
268
271
/** @type {Record<string, number[]> } */
269
272
const cache_values = { } ;
@@ -274,10 +277,10 @@ export function getCacheShapes(config, options) {
274
277
for ( let i = 0 ; i < layer_types . length ; ++ i ) {
275
278
if ( layer_types [ i ] === 'full_attention' ) {
276
279
for ( const kv of [ 'key' , 'value' ] ) {
277
- cache_values [ `past_key_values .${ i } .${ kv } ` ] = [ batch_size , num_key_value_heads , 0 , head_dim ] ;
280
+ cache_values [ `${ pkv_prefix } .${ i } .${ kv } ` ] = [ batch_size , num_key_value_heads , 0 , head_dim ] ;
278
281
}
279
282
} else if ( layer_types [ i ] === 'conv' ) {
280
- cache_values [ `past_conv .${ i } ` ] = [ batch_size , hidden_size , conv_L_cache ] ;
283
+ cache_values [ `${ conv_prefix } _conv .${ i } ` ] = [ batch_size , hidden_size , conv_L_cache ] ;
281
284
} else {
282
285
throw new Error ( `Unsupported layer type: ${ layer_types [ i ] } ` ) ;
283
286
}
You can’t perform that action at this time.
0 commit comments