-
Notifications
You must be signed in to change notification settings - Fork 1.3k
timestep scheduling with np.linspace #8623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
Signed-off-by: ytl0623 <david89062388@gmail.com>
WalkthroughThe diff changes how DDPM and DDIM compute inference timesteps in set_timesteps: instead of using np.arange * step_ratio, they use numpy.linspace from (num_train_timesteps - 1) down to 0 with num_inference_steps, round and cast to int64, then convert to a torch tensor. DDIM also adds validation for steps_offset to be within [0, num_train_timesteps). Public APIs remain unchanged. The selected discrete timesteps and their ordering are altered and the endpoint is ensured to be included. Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes
Areas to review:
Pre-merge checks and finishing touches✅ Passed checks (4 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/schedulers/ddim.py (1)
120-126: Validation logic is now incorrect with linspace approach.
step_ratiois no longer used in timestep generation (line 130 useslinspace). The validation at lines 121-126 incorrectly assumesstep_ratiospacing. With linspace starting atnum_train_timesteps - 1, the constraint should preventsteps_offsetfrom causing timesteps to exceed valid indices[0, num_train_timesteps - 1].Required fix: replace the validation to ensure
steps_offset < 1when using linspace fromnum_train_timesteps - 1, or adjust the linspace start point to accommodate the offset.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/networks/schedulers/ddim.py(1 hunks)monai/networks/schedulers/ddpm.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/schedulers/ddpm.pymonai/networks/schedulers/ddim.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
🔇 Additional comments (2)
monai/networks/schedulers/ddpm.py (1)
128-129: Fix correctly addresses the endpoint issue and tests pass.The linspace approach generates timesteps from
num_train_timesteps - 1to 0, ensuring both endpoints are included and all values are valid array indices for the alpha arrays used in thestep()method. Existing tests verify functionality and error handling.monai/networks/schedulers/ddim.py (1)
179-179: Spacing assumption needs verification.The concern is valid:
self.timestepsuses linspace with rounding (line 130), creating non-uniform gaps. Fixed offset arithmetic (line 179:timestep - step_ratio) doesn't guarantee alignment with actual previous scheduled timestep. If gaps vary,prev_timestepmay misalign, causing incorrect alpha lookups.However, impact on diffusion quality (whether approximation is acceptable per DDIM paper or requires dynamic lookup from
self.timesteps) requires verification against:
- Reference DDIM implementations (HuggingFace, Stable Diffusion, etc.)
- Whether existing tests catch quality degradation
Signed-off-by: ytl0623 <david89062388@gmail.com>
Signed-off-by: ytl0623 <david89062388@gmail.com>
Signed-off-by: ytl0623 <david89062388@gmail.com>
Signed-off-by: ytl0623 <david89062388@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/schedulers/ddim.py (1)
121-126: Validation logic is incorrect for the new implementation.With the linspace approach,
max_timestep = (num_train_timesteps - 1 - steps_offset) + steps_offset = num_train_timesteps - 1regardless of the relationship betweensteps_offsetandstep_ratio. The actual constraint is0 <= steps_offset < num_train_timesteps.Apply this diff:
- step_ratio = self.num_train_timesteps // self.num_inference_steps - if self.steps_offset >= step_ratio: + if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps: raise ValueError( - f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " - f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" - f" the max train timestep." + f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps})." )
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/networks/schedulers/ddim.py(1 hunks)monai/networks/schedulers/ddpm.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- monai/networks/schedulers/ddpm.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/schedulers/ddim.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (codeformat)
🔇 Additional comments (1)
monai/networks/schedulers/ddim.py (1)
128-130: Correctly fixes the endpoint issue.The subtract-then-add pattern ensures max timestep =
num_train_timesteps - 1and min timestep =steps_offset, keeping all indices valid. The linspace approach guarantees the endpoint is reached.
monai/networks/schedulers/ddim.py
Outdated
| timesteps = np.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps).round().astype(np.int64) | ||
| self.timesteps = torch.from_numpy(timesteps).to(device) | ||
| self.timesteps += self.steps_offset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Verify test coverage for the new timestep computation.
The linspace approach changes which discrete training timesteps are selected. Ensure tests validate that: (1) max timestep = num_train_timesteps - 1, (2) min timestep = steps_offset, (3) endpoint behavior differs from the old arange-based method.
As per coding guidelines.
🏁 Script executed:
#!/bin/bash
# Find test files for DDIM scheduler and check if set_timesteps is tested
fd -e py test.*ddim -x rg -l "set_timesteps"Length of output: 50
🏁 Script executed:
#!/bin/bash
# Examine the test file to see what tests exist for set_timesteps
rg -n -A 20 "set_timesteps|def test.*timestep" ./tests/networks/schedulers/test_scheduler_ddim.py -C 5Length of output: 2393
🏁 Script executed:
#!/bin/bash
# Examine the implementation to understand the timestep computation
rg -n -B 5 -A 15 "def set_timesteps" ./monai/networks/schedulers/ddim.pyLength of output: 1127
🏁 Script executed:
#!/bin/bash
# Get complete set_timesteps implementation
sed -n '104,135p' ./monai/networks/schedulers/ddim.pyLength of output: 1815
🏁 Script executed:
#!/bin/bash
# Check git history to see what the previous implementation looked like
git -C . log -p --follow -S "set_timesteps" -- ./monai/networks/schedulers/ddim.py | head -100Length of output: 5302
🏁 Script executed:
#!/bin/bash
# Look for arange in git history to find the previous timestep computation
git -C . log --all -p -- ./monai/networks/schedulers/ddim.py | grep -B 10 -A 10 "arange" | head -50Length of output: 2470
🏁 Script executed:
#!/bin/bash
# Show the commit that introduced linspace vs what was before
git -C . log --oneline --all -- ./monai/networks/schedulers/ddim.py | head -20Length of output: 250
🏁 Script executed:
#!/bin/bash
# Check the most recent commits that dealt with timestep issues
git -C . show aa3cce60 --statLength of output: 331
🏁 Script executed:
#!/bin/bash
# Check the latest refactor commit
git -C . show 616f5f23 --statLength of output: 325
🏁 Script executed:
#!/bin/bash
# Get full diff for the latest refactor
git -C . show 616f5f23 -- ./monai/networks/schedulers/ddim.pyLength of output: 1051
Tests do not validate timestep boundary values or offset behavior.
Current tests only check counts (num_inference_steps, len(timesteps)), not actual values. Add tests that verify: (1) max(timesteps) == num_train_timesteps - 1, (2) min(timesteps) == steps_offset, (3) timestep spacing from linspace method. This is critical given recent modifications to handle steps_offset (commits aa3cce6, 616f5f2).
🤖 Prompt for AI Agents
In monai/networks/schedulers/ddim.py around lines 128 to 130, tests currently
only assert counts but not actual timestep values; add unit tests that construct
the scheduler with known num_train_timesteps and steps_offset and assert: (1)
max(timesteps) == num_train_timesteps - 1, (2) min(timesteps) == steps_offset,
and (3) that the spacing matches the np.linspace behavior (i.e., consecutive
differences correspond to the rounded linspace sequence) by computing the
expected numpy linspace sequence with the same rounding/astype logic and
comparing elementwise to the scheduler's torch timesteps on the same device;
ensure tests cover multiple configurations (different num_inference_steps and
nonzero steps_offset).
… steps_offset) + steps_offset = num_train_timesteps - 1 regardless of the relationship between steps_offset and step_ratio. The actual constraint is 0 <= steps_offset < num_train_timesteps. Signed-off-by: ytl0623 <david89062388@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
monai/networks/schedulers/ddim.py (1)
120-121: Validation logic is correct.The range check ensures the linspace start is non-negative and the final max timestep stays at
num_train_timesteps - 1.Static analysis suggests defining a custom exception or shortening the message to comply with TRY003, though this is a minor style concern.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/schedulers/ddim.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/schedulers/ddim.py
🪛 Ruff (0.14.3)
monai/networks/schedulers/ddim.py
121-121: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.10)
🔇 Additional comments (1)
monai/networks/schedulers/ddim.py (1)
123-129: Timestep computation is correct and validated by tests.The linspace approach with rounding produces nearly-uniform spacing (gaps of 10–11). The
step()method's prev_timestep approximation (gap =num_train_timesteps // num_inference_steps) works correctly despite minor non-uniformity. Thetest_full_timestep_looptest validates numerical correctness end-to-end, confirmingalphas_cumprodindexing is sound.
|
Hi @virginiafdez, @KumoLiu, @Nic-Ma and @ericspod, Sorry to bother. Thanks in advance! |
Fixes #8600
Description
The
np.linspaceapproach generates a descending array that starts exactly at 999 and ends exactly at 0 (after rounding), ensuring the scheduler samples the entire intended trajectory.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.