
Update default naflexvit positional embedding interpolation mode to bilinear #2543


Closed
wants to merge 1 commit

Conversation


@drhead commented Jul 9, 2025

The default positional embedding interpolation mode for the NaFlex SigLIP2 model doesn't match what is used for SigLIP2 in its official implementations. In fact, looking at the Transformers implementation, there's not even an option to have it as anything but bilinear: https://github.com/huggingface/transformers/blob/2781ad092dad77ff554cb70ec130b97e44cfba78/src/transformers/models/siglip2/modeling_siglip2.py#L174

Bilinear is probably the more appropriate default since it is what is used in official implementations of SigLIP 2. Using the wrong mode causes significant deviation in the cosine similarity of outputs between the Transformers and timm implementations of SigLIP 2. Setting the mode to bilinear and using an input image that doesn't need resizing (to avoid preprocessing discrepancies) yields identical intermediate and final outputs between the two implementations.
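A minimal sketch of the kind of deviation being described, using only torch.nn.functional.interpolate on a random grid; the embedding width and grid sizes below are made up, not the SigLIP2 values:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, src_grid, dst_grid = 768, 16, 24                 # hypothetical sizes
pos_embed = torch.randn(1, embed_dim, src_grid, src_grid)   # (1, C, H, W)

# Resize the same positional-embedding grid with both modes.
out = {
    mode: F.interpolate(pos_embed, size=(dst_grid, dst_grid),
                        mode=mode, align_corners=False)
    for mode in ("bilinear", "bicubic")
}

# Cosine similarity per output position between the two modes.
a = out["bilinear"].flatten(2).squeeze(0)  # (C, dst*dst)
b = out["bicubic"].flatten(2).squeeze(0)
cos = F.cosine_similarity(a, b, dim=0)
print(f"mean cos sim: {cos.mean():.4f}  min: {cos.min():.4f}")
```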

@redhottensors

Haha. Looks like I was a bit faster by being less helpful. #2542

@redhottensors

I do not think that this is the correct place to make the change though. See the issue I opened.

@rwightman (Collaborator)

@drhead you cannot change the default globally, only the config for SigLIP-specific models can be changed.

I would like some convincing data for the change. I do not care about differences between the transformers and timm impl, only between the original jax models and timm... if torch bicubic is not convincingly different from jax bilinear then I will leave it as bicubic, as I've found it to be more robust overall... from my zero-shot eval comparisons there wasn't a convincing argument either way.
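A rough sketch of what such a cross-framework check could look like, treating jax.image.resize 'bilinear' as the JAX-side reference; the grid sizes and data are made up, and this measures only numerical distance, not zero-shot quality:

```python
import numpy as np
import torch
import torch.nn.functional as F
import jax.image

# Random stand-in for a positional-embedding grid, NCHW layout.
grid = np.random.default_rng(0).standard_normal((1, 8, 16, 16)).astype(np.float32)
out_shape = (1, 8, 24, 24)

# JAX-side reference resize.
jax_bilinear = np.asarray(jax.image.resize(grid, out_shape, method="bilinear"))

# Compare each torch mode against the JAX bilinear output.
for mode in ("bilinear", "bicubic"):
    torch_out = F.interpolate(torch.from_numpy(grid), size=out_shape[2:],
                              mode=mode, align_corners=False).numpy()
    diff = np.abs(torch_out - jax_bilinear).max()
    print(f"torch {mode:<8} vs jax 'bilinear': max abs diff = {diff:.6f}")
```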

@redhottensors

> @drhead you cannot change the default globally, only the config for SigLIP-specific models can be changed.
>
> I would like some convincing data for the change. I do not care about differences between the transformers and timm impl, only between the original jax models and timm... if torch bicubic is not convincingly different from jax bilinear then I will leave it as bicubic, as I've found it to be more robust overall... from my zero-shot eval comparisons there wasn't a convincing argument either way.

Sorry, our team has higher priorities than getting this set up with JAX.

@rwightman (Collaborator)

@redhottensors okay, well until someone has time to verify I'll stick with my original analysis that bicubic is the best choice. There are differences between torch and jax interpolation modes of the same 'type'. I evaluated both in zero-shot and bicubic appeared to 'win' across a few scenarios.

It is expected that the timm and transformers impl would be very close numerically (aside from this difference), but what's more important in these decisions is what works best in comparison to the original across numerous downstream use cases. I will take another look when I get back to integrating with OpenCLIP.

@drhead (Author) commented Jul 9, 2025

Closing, since I would agree this is a documentation bug, as #2542 was updated to reflect. At least as far as what I know on this issue goes:

  • I know torch's interpolate algorithm is a bit of an odd outlier since it does both axes in one kernel, whereas most implementations use a separable kernel working on one axis at a time. JAX, as far as I can tell from a cursory look, is doing a naive matmul, which, while very inefficient, does work fine (see the sketch after this list). NVIDIA DALI uses a separable kernel and is significantly faster than PyTorch's interpolate. I never looked in much detail at how torch's implementation works, but I suspect it'd have to be more than just rounding errors here.

  • I recall a paper from some time ago, which I cannot for the life of me find, on resizing image inputs with a learned resampling kernel, which helped on some classification problems. For that reason I wouldn't be surprised if something other than bicubic was better, but it rightfully shouldn't behave better than what the model was trained on, unless we're talking about training with all of the model weights unfrozen rather than just the head.
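A small illustration of the separable "resize as matmul" idea from the first bullet (my own sketch, not JAX's or DALI's actual code): build a 1D linear-interpolation weight matrix with half-pixel sample centers and apply it along each spatial axis in turn; for plain upsampling this is expected to land very close to torch's fused bilinear kernel.

```python
import torch
import torch.nn.functional as F

def linear_resize_matrix(src: int, dst: int) -> torch.Tensor:
    """(dst, src) matrix whose rows are 1D linear-interp weights (half-pixel centers)."""
    # Map each output sample back into source coordinates, clamped at the borders.
    x = (torch.arange(dst, dtype=torch.float64) + 0.5) * src / dst - 0.5
    x = x.clamp(0, src - 1)
    lo = x.floor().long().clamp(max=src - 2)
    frac = (x - lo.to(x.dtype)).clamp(0, 1)
    w = torch.zeros(dst, src, dtype=torch.float64)
    w[torch.arange(dst), lo] = 1 - frac
    w[torch.arange(dst), lo + 1] = frac
    return w

x = torch.randn(1, 3, 16, 16, dtype=torch.float64)
wh = linear_resize_matrix(16, 24)   # resize along H
ww = linear_resize_matrix(16, 24)   # resize along W

# Two separable 1D passes expressed as matmuls (einsum over H, then W).
sep = torch.einsum('oh,nchw,pw->ncop', wh, x, ww)
ref = F.interpolate(x, size=(24, 24), mode='bilinear', align_corners=False)
# Expected to be very small for this pure-upsampling case.
print(f"max abs diff vs F.interpolate: {(sep - ref).abs().max():.2e}")
```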

drhead closed this Jul 9, 2025
@rwightman (Collaborator)

@drhead yup, agreed with the above, but as I've pointed out to others before, 'what the model was trained on' is the JAX implementation of 'bilinear', not the torch impl of 'bilinear'. And when considering the image interpolation preprocessing we actually have for the original:

  • tf.image.resize 'bilinear' for the input image
  • jax.image.scale_and_translate + 'bilinear' for the position embedding.

And in torch we will have the PIL or torchvision impl of ? + torch.nn.functional.interpolate of ? ... simply matching strings doesn't necessarily get you the best match or end result, given that all of the implementations differ. If the implementations of bilinear -> bilinear across frameworks are sufficiently different, then I usually use bicubic, as it tends to behave as well (given the differences) or better across more size ranges and scenarios.
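A sketch of the two torch-side resize steps in question, with the two '?' mode choices left explicit. The torchvision and torch calls are real, but which modes best stand in for tf.image.resize 'bilinear' and jax.image.scale_and_translate 'bilinear' is exactly the open question; the image and grid sizes are placeholders:

```python
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as TVF

# Input-image resize (counterpart of tf.image.resize 'bilinear'); the mode is a choice.
img = Image.new("RGB", (640, 480))            # placeholder image
img = TVF.resize(img, [256, 256], interpolation=InterpolationMode.BILINEAR)

# Position-embedding resize (counterpart of jax.image.scale_and_translate 'bilinear');
# hypothetical grid sizes, and the mode here is the setting this PR is about.
pos_embed = torch.randn(1, 768, 16, 16)
pos_embed = F.interpolate(pos_embed, size=(24, 24),
                          mode="bicubic", align_corners=False)
```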
