```python
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> m = pyhf.simplemodels.hepdata_like([10], [15], [5])
>>> pyhf.infer.mle.fit([12.5], m)
```
The last call crashes, and the error message includes what may be a hint:
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using jnp together with import jax.numpy as jnp rather than using np via import numpy as np. If this error arises on a line that involves array indexing, like x[idx], it may be that the array being indexed x is a raw numpy.ndarray while the indices idx are a JAX Tracer instance; in that case, you can instead write jax.device_put(x)[idx].
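To make the hint concrete, here is a minimal standalone sketch in plain JAX (not pyhf code). The functions `log_sum_numpy`, `log_sum_jax`, `lookup`, and `fails_under_jit` and the array `TABLE` are hypothetical names chosen for illustration. It shows that calling raw `numpy` on a traced value inside `jax.jit` raises `jax.errors.TracerArrayConversionError`, that the `jax.numpy` equivalent traces fine, and that the `jax.device_put(x)[idx]` workaround the message suggests lets a raw `numpy.ndarray` be indexed by a traced index.

```python
import numpy as np
import jax
import jax.numpy as jnp


def log_sum_numpy(x):
    # Raw numpy: np.asarray asks the JAX tracer for a concrete array,
    # which is impossible while jax.jit is tracing the function.
    return np.log(np.asarray(x)).sum()


def log_sum_jax(x):
    # jax.numpy operates on tracers directly, so this traces fine.
    return jnp.log(jnp.asarray(x)).sum()


# Hypothetical lookup table stored as a raw numpy.ndarray.
TABLE = np.array([1.0, 2.0, 3.0])


def lookup(idx):
    # Indexing TABLE directly with a traced idx can fail under jit.
    # Following the hint, move the array to a JAX device first so the
    # indexing becomes a JAX operation.
    return jax.device_put(TABLE)[idx]


def fails_under_jit(fn, *args):
    """Return True if running fn under jax.jit raises TracerArrayConversionError."""
    try:
        jax.jit(fn)(*args)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


if __name__ == "__main__":
    x = jnp.array([1.0, 2.0, 3.0])
    print(fails_under_jit(log_sum_numpy, x))  # True
    print(fails_under_jit(log_sum_jax, x))    # False
    print(jax.jit(lookup)(1))
```

If the pyhf traceback points into backend-agnostic code that calls `numpy` directly on a value that is being traced, that would match the first pattern; this is a guess from the hint text, not something the traceback above confirms.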