I've noticed more and more `timm` backbones being added here, which is great, but a lot of the effort is currently duplicating features of `timm` itself, i.e. tracking channel numbers, modifying the networks, etc.
`timm` has a `features_only` arg in the model factory that will return a model set up as a backbone to produce pyramid features. The resulting model has a `.feature_info` attribute you can query to understand what the channels of each output are, the approximate reduction factor, etc.
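For example, a minimal sketch of what that looks like (assuming a recent `timm`; `resnet50` is just an illustrative model name):

```python
import torch
import timm

# features_only=True returns a backbone that outputs a list of pyramid
# feature maps instead of classification logits.
model = timm.create_model('resnet50', features_only=True, pretrained=False)

# feature_info describes each tap point: channel counts and reduction factors.
print(model.feature_info.channels())   # e.g. [64, 256, 512, 1024, 2048]
print(model.feature_info.reduction())  # e.g. [2, 4, 8, 16, 32]

features = model(torch.randn(1, 3, 224, 224))
for f in features:
    print(f.shape)  # one tensor per pyramid level
```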
I've adapted the unet and deeplab impl here in the past to use this successfully, although it was quick hack-and-train work, nothing to serve as a clean example.
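The pattern boils down to something like this hypothetical minimal encoder wrapper (a sketch, not the hack referenced above; `TimmEncoder` is an invented name):

```python
import timm
import torch.nn as nn

class TimmEncoder(nn.Module):
    """Hypothetical generic encoder: any timm model name -> pyramid features."""
    def __init__(self, model_name: str, pretrained: bool = True):
        super().__init__()
        self.backbone = timm.create_model(
            model_name, features_only=True, pretrained=pretrained)
        # Channel counts a decoder (unet/deeplab style) needs for its convs.
        self.out_channels = self.backbone.feature_info.channels()

    def forward(self, x):
        # Returns a list of feature maps, shallowest (highest res) first.
        return self.backbone(x)
```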
If this were supported, any timm model (vit excluded right now) could be used as a backbone in generic fashion, just by passing the model name string to the creation fn, possibly with a small config mapping of model types to index specifications (some models have slightly different `out_indices` alignment to strides if they happen to be a stride-64 model, or don't have a stride=2 feature, etc.). All tap points are the latest possible point for a given feature map stride. Some, but not all, of the timm backbones also support an `output_stride=` arg that will dilate the blocks appropriately for network strides of 8 or 16, as sketched below.
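For instance (a sketch; `resnet50` is illustrative, and not every model family accepts `output_stride`):

```python
import timm

# out_indices picks which tap points to return; for a typical stride-32 model,
# indices 1..4 correspond to strides 4, 8, 16, 32. output_stride=8 dilates the
# later blocks so the deepest features stay at an effective stride of 8.
model = timm.create_model(
    'resnet50',
    features_only=True,
    out_indices=(1, 2, 3, 4),
    output_stride=8,
)
print(model.feature_info.reduction())  # e.g. [4, 8, 8, 8] with dilation applied
```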
Some references:
- https://rwightman.github.io/pytorch-image-models/feature_extraction/#multi-scale-feature-maps-feature-pyramid
- https://github.com/rwightman/efficientdet-pytorch/blob/92bb66fd0cf91d0e23fe8b10cba97e2f0bb9884f/effdet/efficientdet.py#L554-L569
For most of the models, the features are extracted by flattening part of the backbone model via a wrapper. A few models where the feature taps are embedded deep within the model use hooks, which causes some issues with torchscript, but that will likely be fixed soon in PyTorch.