55 lines
2.5 KiB
Python
55 lines
2.5 KiB
Python
|
import torch
|
||
|
import comfy.utils
|
||
|
|
||
|
class PatchModelAddDownscale:
    """Node that resizes the hidden state at one input block for part of sampling.

    ``patch`` clones the given model and registers two callbacks on the clone:
    an input-block patch that resizes the hidden state by ``downscale_factor``
    at the selected block while the current sigma lies inside a configured
    window, and an output-block patch that resizes the hidden state back to
    the spatial size of the skip tensor it is paired with.
    """

    # Resampling modes offered for both the downscale and the upscale step.
    upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]

    @classmethod
    def INPUT_TYPES(s):
        """Return the specification of the node's required inputs."""
        return {
            "required": {
                "model": ("MODEL",),
                "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
                "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
                "downscale_after_skip": ("BOOLEAN", {"default": True}),
                "downscale_method": (s.upscale_methods,),
                "upscale_method": (s.upscale_methods,),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    @staticmethod
    def _scaled_length(length, downscale_factor):
        """Return ``length`` scaled by ``1 / downscale_factor``, never below 1."""
        # Small feature maps combined with a large factor would otherwise
        # round to 0, which is not a usable target size for a resize.
        return max(1, round(length * (1.0 / downscale_factor)))

    @staticmethod
    def _spatial_mismatch(h, hsp):
        """Return True when the last two dimensions of ``h`` and ``hsp`` differ."""
        # Both trailing axes are compared: each axis is rounded independently
        # when downscaling, so one of them can round back to its original size
        # while the other one shrinks.
        return tuple(h.shape[-2:]) != tuple(hsp.shape[-2:])

    def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
        """Return a one-element tuple holding a patched clone of ``model``.

        Args:
            model: Object providing ``get_model_object``, ``clone`` and the
                ``set_model_*_block_patch`` registration methods used below.
            block_number: Value compared with ``transformer_options["block"][1]``
                to select the input block whose hidden state is resized.
            downscale_factor: Divisor applied to the last two dimensions of
                the hidden state.
            start_percent: Converted with ``percent_to_sigma`` into the upper
                sigma bound of the active window.
            end_percent: Converted with ``percent_to_sigma`` into the lower
                sigma bound of the active window.
            downscale_after_skip: Selects which of the two input-block patch
                registration methods receives the downscale callback.
            downscale_method: Resampling mode passed to ``common_upscale``
                when shrinking.
            upscale_method: Resampling mode passed to ``common_upscale`` when
                restoring the size.

        Returns:
            ``(patched_model,)``; the ``model`` argument itself is not patched.
        """
        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)
        scaled_length = self._scaled_length
        spatial_mismatch = self._spatial_mismatch

        def input_block_patch(h, transformer_options):
            # Only the selected block is touched, and only while the first
            # entry of the current sigmas lies within [sigma_end, sigma_start].
            if transformer_options["block"][1] == block_number:
                sigma = transformer_options["sigmas"][0].item()
                if sigma_end <= sigma <= sigma_start:
                    # NOTE(review): assumes the last axis is width and the
                    # second-to-last is height, matching the argument order
                    # used for the common_upscale call — confirm.
                    h = comfy.utils.common_upscale(
                        h,
                        scaled_length(h.shape[-1], downscale_factor),
                        scaled_length(h.shape[-2], downscale_factor),
                        downscale_method,
                        "disabled",
                    )
            return h

        def output_block_patch(h, hsp, transformer_options):
            # Bring h back to hsp's spatial size whenever either axis differs.
            if spatial_mismatch(h, hsp):
                h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return (m, )
|
||
|
|
||
|
# Maps the node identifier string to the class that implements it.
NODE_CLASS_MAPPINGS = {"PatchModelAddDownscale": PatchModelAddDownscale}
|
||
|
|
||
|
# Maps the node identifier string to its display name.
# Sampling
NODE_DISPLAY_NAME_MAPPINGS = {"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)"}
|