diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py
index 539543f5e7..29e8ba28e7 100644
--- a/timm/models/naflexvit.py
+++ b/timm/models/naflexvit.py
@@ -532,7 +532,11 @@ def _apply_learned_pos_embed(
             pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
         else:
             # Resize if needed - directly using F.interpolate
-            _interp_size = to_2tuple(max(grid_size)) if self.pos_embed_ar_preserving else grid_size
+            if self.pos_embed_ar_preserving:
+                L = max(grid_size)
+                _interp_size = L, L
+            else:
+                _interp_size = grid_size
             pos_embed_flat = F.interpolate(
                 self.pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W
                 size=_interp_size,
@@ -968,7 +972,7 @@ def __init__(
             cfg: Model configuration. If None, uses default NaFlexVitCfg.
             in_chans: Number of input image channels.
             num_classes: Number of classification classes.
-            img_size: Input image size for backwards compatibility.
+            img_size: Input image size (for backwards compatibility with classic vit).
             **kwargs: Additional config parameters to override cfg values.
         """
         super().__init__()
@@ -1523,9 +1527,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     return {
         'url': url,
         'num_classes': 1000,
-        'input_size': (3, 256, 256),
+        'input_size': (3, 384, 384),
         'pool_size': None,
-        'crop_pct': 0.95,
+        'crop_pct': 1.0,
         'interpolation': 'bicubic',
         'mean': IMAGENET_INCEPTION_MEAN,
         'std': IMAGENET_INCEPTION_STD,
@@ -1537,11 +1541,19 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 
 
 default_cfgs = generate_default_cfgs({
-    'naflexvit_base_patch16_gap': _cfg(),
-    'naflexvit_base_patch16_map': _cfg(),
-
-    'naflexvit_base_patch16_siglip': _cfg(),
-    'naflexvit_so400m_patch16_siglip': _cfg(),
+    'naflexvit_base_patch16_gap.e300_s576_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'naflexvit_base_patch16_par_gap.e300_s576_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'naflexvit_base_patch16_parfac_gap.e300_s576_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'naflexvit_base_patch16_map.untrained': _cfg(),
+
+    'naflexvit_base_patch16_siglip.untrained': _cfg(),
+    'naflexvit_so400m_patch16_siglip.untrained': _cfg(),
 })
 
 
@@ -1623,6 +1635,45 @@ def naflexvit_base_patch16_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
     return model
 
 
+@register_model
+def naflexvit_base_patch16_par_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
+    """ViT-Base with NaFlex functionality, aspect preserving pos embed, global average pooling.
+    """
+    cfg = NaFlexVitCfg(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        init_values=1e-5,
+        pos_embed_ar_preserving=True,
+        global_pool='avg',
+        reg_tokens=4,
+        fc_norm=True,
+    )
+    model = _create_naflexvit('naflexvit_base_patch16_par_gap', pretrained=pretrained, cfg=cfg, **kwargs)
+    return model
+
+
+@register_model
+def naflexvit_base_patch16_parfac_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
+    """ViT-Base with NaFlex functionality, aspect preserving & factorized pos embed, global average pooling.
+    """
+    cfg = NaFlexVitCfg(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        init_values=1e-5,
+        pos_embed_ar_preserving=True,
+        pos_embed='factorized',
+        global_pool='avg',
+        reg_tokens=4,
+        fc_norm=True,
+    )
+    model = _create_naflexvit('naflexvit_base_patch16_parfac_gap', pretrained=pretrained, cfg=cfg, **kwargs)
+    return model
+
+
 @register_model
 def naflexvit_base_patch16_map(pretrained: bool = False, **kwargs) -> NaFlexVit:
     """ViT-Base with NaFlex functionality and MAP attention pooling.
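
Quick smoke test (not part of the diff; a minimal sketch assuming the standard timm factory API): the model name below comes from one of the @register_model functions added above, and the 384x384 input matches the new default_cfg input_size. The .e300_s576_in1k cfgs carry hf_hub_id='timm/', so pretrained=True on the gap/par_gap/parfac_gap variants would resolve weights via the HF Hub; the .untrained cfgs ship no weights.

import timm
import torch

# Hypothetical usage sketch, not from the diff: build one of the newly
# registered aspect-preserving variants with random init.
model = timm.create_model('naflexvit_base_patch16_par_gap', pretrained=False)
model.eval()

x = torch.randn(1, 3, 384, 384)  # matches the new default input_size (3, 384, 384)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected torch.Size([1, 1000]) with the default 1000-class head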