Ensure that all the code works as expected with `torch.compile`. It would be good to add `torch.compile` coverage to the unit tests as well.
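A minimal sketch of such a test, assuming PyTorch 2.x and using a hypothetical `TinyModel` as a stand-in for the code under review. It runs the same input through the eager module and a `torch.compile`d copy, then checks that the outputs match.

```python
import unittest

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the module under review."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


class TestTorchCompile(unittest.TestCase):
    def test_compiled_matches_eager(self) -> None:
        torch.manual_seed(0)
        model = TinyModel().eval()
        x = torch.randn(2, 8)

        with torch.no_grad():
            expected = model(x)
            # backend="eager" traces the model with TorchDynamo but skips
            # Inductor code generation, so no C++ toolchain is required.
            # fullgraph=True turns any graph break into an error.
            compiled = torch.compile(model, backend="eager", fullgraph=True)
            actual = compiled(x)

        # Compare compiled vs. eager outputs within default tolerances.
        torch.testing.assert_close(actual, expected)
```

Using `fullgraph=True` makes graph breaks fail the test instead of silently falling back to eager execution. Dropping `backend="eager"` exercises the default Inductor backend as well, at the cost of needing a working compiler toolchain in CI.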