Description
tests/schedulers/test_scheduler_flax.py:304 (FlaxDDPMSchedulerTest.test_full_loop_no_noise)
Array(3.7847595, dtype=float32) != 0.01
Expected :0.01
Actual :Array(3.7847595, dtype=float32)
self = <test_scheduler_flax.FlaxDDPMSchedulerTest testMethod=test_full_loop_no_noise>

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        key1, key2 = random.split(random.PRNGKey(0))

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            output = scheduler.step(state, residual, t, sample, key1)
            pred_prev_sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)

            # if t > 0:
            #     noise = self.dummy_sample_deter
            #     variance = scheduler.get_variance(t) ** (0.5) * noise
            #
            # sample = pred_prev_sample + variance
            sample = pred_prev_sample

        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 251.26245) < 1e-2
            assert abs(result_mean - 0.32716465) < 1e-3
        else:
>           assert abs(result_sum - 255.1113) < 1e-2
E           assert Array(3.7847595, dtype=float32) < 0.01
E            +  where Array(3.7847595, dtype=float32) = abs((Array(251.32654, dtype=float32) - 255.1113))

schedulers/test_scheduler_flax.py:341: AssertionError
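To put the failure in numbers: the observed sum (251.32654) is about 3.78 away from the non-TPU reference, but it is also not within the 1e-2 tolerance of the TPU reference (about 0.064 off). The sketch below reproduces only the test's comparison in plain Python, using the values printed in the log; the helper name `within` is hypothetical and not part of the test suite.

```python
# Values copied from the failing run above (CPU backend on an Apple M3 Pro).
observed_sum = 251.32654
cpu_reference = 255.1113   # expected sum on the non-TPU branch of the test
tpu_reference = 251.26245  # expected sum on the TPU branch of the test
tolerance = 1e-2           # absolute tolerance used for result_sum


def within(observed: float, reference: float, tol: float = tolerance) -> bool:
    """Hypothetical helper mirroring the test's check: abs(observed - reference) < tol."""
    return abs(observed - reference) < tol


cpu_gap = abs(observed_sum - cpu_reference)
tpu_gap = abs(observed_sum - tpu_reference)

print(f"gap to non-TPU reference: {cpu_gap:.5f}")  # about 3.78476
print(f"gap to TPU reference:     {tpu_gap:.5f}")  # about 0.06409
print("passes non-TPU check:", within(observed_sum, cpu_reference))
print("passes TPU check:    ", within(observed_sum, tpu_reference))
```

Neither branch's constant matches this machine within tolerance, which suggests the reference values are backend-specific rather than the scheduler output being grossly wrong; that is an inference from these three numbers only.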
I'm running this on an Apple M3 Pro, without a TPU or GPU.
I'm planning to go through all of your tests and dependencies until Python 3.10, 3.11, 3.12, and 3.13 are supported, in addition to the existing 3.8 and 3.9 support.
PS: your grain-nightly dependency doesn't seem to support Python 3.8 or 3.9:
ERROR: Ignored the following versions that require a different python version: 0.0.1 Requires-Python >=3.10; 0.0.2 Requires-Python >=3.10; 0.0.3 Requires-Python >=3.10; 0.0.4 Requires-Python >=3.10
ERROR: Could not find a version that satisfies the requirement grain-nightly (from versions: none)
ERROR: No matching distribution found for grain-nightly
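pip's message means every published grain-nightly release (0.0.1 through 0.0.4) declares Requires-Python >=3.10, so on 3.8 or 3.9 pip discards all of them and finds nothing to install. The sketch below mimics that filtering with simple tuple comparison; it is an illustration only, handles just `>=` lower bounds, and the function names are hypothetical. Real resolution follows the full PEP 440 specifier rules (for example via `packaging.specifiers.SpecifierSet`).

```python
from typing import List, Tuple


def parse_min_version(spec: str) -> Tuple[int, ...]:
    """Parse a lower-bound-only specifier such as '>=3.10' into (3, 10)."""
    spec = spec.strip()
    if not spec.startswith(">="):
        raise ValueError(f"only '>=' specifiers are handled in this sketch: {spec!r}")
    return tuple(int(part) for part in spec[2:].split("."))


def compatible(targets: List[str], requires_python: str) -> List[str]:
    """Return the target Python versions that satisfy the Requires-Python bound."""
    minimum = parse_min_version(requires_python)
    # Compare numerically as tuples, so (3, 9) < (3, 10) as intended.
    return [v for v in targets if tuple(int(p) for p in v.split(".")) >= minimum]


targets = ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
print(compatible(targets, ">=3.10"))  # ['3.10', '3.11', '3.12', '3.13']
```

If grain is optional for older interpreters, one possible approach (an assumption about this project, not something it currently does) is a PEP 508 environment marker in the requirement string, e.g. `grain-nightly; python_version >= "3.10"`, so that pip skips it on 3.8 and 3.9.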
Is your setup.py up to date? Which Python (CPython) versions are you testing on?