
Python 3.11.6: test failure in tests.schedulers.test_scheduler_flax.FlaxDDPMSchedulerTest.test_full_loop_no_noise #138

@SamuelMarks

tests/schedulers/test_scheduler_flax.py:304 (FlaxDDPMSchedulerTest.test_full_loop_no_noise)
Array(3.7847595, dtype=float32) != 0.01

Expected :0.01
Actual   :Array(3.7847595, dtype=float32)
self = <test_scheduler_flax.FlaxDDPMSchedulerTest testMethod=test_full_loop_no_noise>

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
    
        num_trained_timesteps = len(scheduler)
    
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        key1, key2 = random.split(random.PRNGKey(0))
    
        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)
    
            # 2. predict previous mean of sample x_t-1
            output = scheduler.step(state, residual, t, sample, key1)
            pred_prev_sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)
    
            # if t > 0:
            #     noise = self.dummy_sample_deter
            #     variance = scheduler.get_variance(t) ** (0.5) * noise
            #
            # sample = pred_prev_sample + variance
            sample = pred_prev_sample
    
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))
    
        if jax_device == "tpu":
            assert abs(result_sum - 251.26245) < 1e-2
            assert abs(result_mean - 0.32716465) < 1e-3
        else:
>           assert abs(result_sum - 255.1113) < 1e-2
E           assert Array(3.7847595, dtype=float32) < 0.01
E            +  where Array(3.7847595, dtype=float32) = abs((Array(251.32654, dtype=float32) - 255.1113))

schedulers/test_scheduler_flax.py:341: AssertionError
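
As a sanity check on the numbers in the traceback, the sketch below redoes the failing comparison in plain Python, using only values printed above: the observed sum 251.32654 and the two reference sums hard-coded in the test (251.26245 for TPU, 255.1113 otherwise). It is an illustration, not part of the test suite, and the helper name `deviations` is hypothetical. It shows that the CPU result misses the non-TPU reference by about 3.78, and that it actually lies much closer to the TPU reference (about 0.064 away), though still outside the 1e-2 tolerance.

```python
import math

# Values copied from the traceback: the observed result_sum on the M3 Pro
# and the two reference sums hard-coded in test_full_loop_no_noise.
OBSERVED_SUM = 251.32654
REFERENCE_SUMS = {"tpu": 251.26245, "non-tpu": 255.1113}
ABS_TOL = 1e-2  # absolute tolerance the test uses for result_sum


def deviations(observed, references, abs_tol):
    """Hypothetical helper: absolute deviation from each reference value and
    whether it falls within the absolute tolerance (rel_tol disabled)."""
    report = {}
    for name, expected in references.items():
        diff = abs(observed - expected)
        within = math.isclose(observed, expected, rel_tol=0.0, abs_tol=abs_tol)
        report[name] = (diff, within)
    return report


for name, (diff, within) in deviations(OBSERVED_SUM, REFERENCE_SUMS, ABS_TOL).items():
    print(f"{name:>8}: |diff| = {diff:.5f}, within {ABS_TOL}: {within}")
```

Both entries report `False`, which is consistent with the CPU result matching neither reference within the test's tolerance.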

I'm running this on an Apple M3 Pro, on CPU only (no TPU or GPU).

I plan to go through all of your tests and dependencies until Python 3.10, 3.11, 3.12 and 3.13 are supported, in addition to your existing 3.8 and 3.9 support.

PS: Your grain-nightly dependency doesn't appear to support Python 3.8 or 3.9:

ERROR: Ignored the following versions that require a different python version: 0.0.1 Requires-Python >=3.10; 0.0.2 Requires-Python >=3.10; 0.0.3 Requires-Python >=3.10; 0.0.4 Requires-Python >=3.10
ERROR: Could not find a version that satisfies the requirement grain-nightly (from versions: none)
ERROR: No matching distribution found for grain-nightly
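
pip's message means every published grain-nightly release (0.0.1 through 0.0.4) declares `Requires-Python >=3.10`, so on 3.8 or 3.9 there is nothing it is allowed to install. The sketch below is a deliberately simplified, stdlib-only check that understands a single `>=X.Y` clause (not full PEP 440 specifier matching); the function names are hypothetical. It shows which CPython versions that constraint admits. One common way for a project that still targets 3.8/3.9 to handle this is a PEP 508 environment marker on the dependency, e.g. `grain-nightly; python_version >= "3.10"`, or raising its own `python_requires`.

```python
from __future__ import annotations


def parse_min_version(requires_python: str) -> tuple[int, ...]:
    """Hypothetical helper: parse a single '>=X.Y' clause into a version tuple.

    Simplified on purpose: anything other than one '>=' clause is rejected
    instead of being matched with full PEP 440 semantics.
    """
    spec = requires_python.strip()
    if not spec.startswith(">=") or "," in spec:
        raise ValueError(f"unsupported specifier: {requires_python!r}")
    return tuple(int(part) for part in spec[2:].strip().split("."))


def allowed(python_version: tuple[int, ...], requires_python: str) -> bool:
    """True if the interpreter version satisfies the '>=X.Y' constraint."""
    return python_version >= parse_min_version(requires_python)


# Constraint reported by pip for grain-nightly 0.0.1 through 0.0.4.
REQUIRES_PYTHON = ">=3.10"

for minor in range(8, 14):  # CPython 3.8 through 3.13
    print(f"Python 3.{minor}: installable = {allowed((3, minor), REQUIRES_PYTHON)}")
```

Tuple comparison does the work here: `(3, 9) >= (3, 10)` is false because 9 < 10, which is exactly why 3.8 and 3.9 find no matching distribution.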

Is your setup.py up to date? Which Python (CPython) versions are you testing on?
