121 lines
4.5 KiB
Python
121 lines
4.5 KiB
Python
|
import comfy.utils
|
||
|
import folder_paths
|
||
|
import torch
|
||
|
import logging
|
||
|
|
||
|
def _build_hypernetwork_mlp(attn_weights, activation_cls, activate_output,
                            is_layer_norm, use_dropout, last_layer_dropout):
    """Rebuild one hypernetwork MLP as a ``torch.nn.Sequential``.

    Args:
        attn_weights: mapping holding ``"<name>.weight"`` and ``"<name>.bias"``
            tensors for every layer. Layers are rebuilt in the order their
            ``.weight`` keys appear in the mapping.
        activation_cls: module class instantiated after each linear layer, or
            ``None`` when no activation module should be inserted.
        activate_output: when true, the activation is also added after the
            last entry.
        is_layer_norm: when true, each linear entry is immediately followed by
            a LayerNorm entry, which is consumed together with it.
        use_dropout: insert ``Dropout(p=0.3)`` after eligible layers.
        last_layer_dropout: also allow dropout when the current entry is the
            second-to-last ``.weight`` entry.

    Returns:
        The rebuilt ``torch.nn.Sequential`` (in the default training mode).
    """
    names = [key[:-len(".weight")] for key in attn_weights.keys() if key.endswith(".weight")]
    count = len(names)
    layers = []
    i = 0
    while i < count:
        lin_name = names[i]
        # Positions are measured over all ".weight" entries (LayerNorm ones
        # included) and are taken before a LayerNorm entry is consumed below.
        last_layer = (i == count - 1)
        penultimate_layer = (i == count - 2)

        lin_weight = attn_weights['{}.weight'.format(lin_name)]
        lin_bias = attn_weights['{}.bias'.format(lin_name)]
        # Linear weights are stored as (out_features, in_features).
        layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
        layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
        layers.append(layer)

        if activation_cls is not None and ((not last_layer) or activate_output):
            layers.append(activation_cls())

        if is_layer_norm:
            # The next ".weight" entry belongs to this layer's LayerNorm.
            i += 1
            ln_name = names[i]
            ln_weight = attn_weights['{}.weight'.format(ln_name)]
            ln_bias = attn_weights['{}.bias'.format(ln_name)]
            ln = torch.nn.LayerNorm(ln_weight.shape[0])
            ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
            layers.append(ln)

        if use_dropout and (not last_layer) and (not penultimate_layer or last_layer_dropout):
            layers.append(torch.nn.Dropout(p=0.3))

        i += 1

    return torch.nn.Sequential(*layers)


def load_hypernetwork_patch(path, strength):
    """Load a hypernetwork file and wrap it as an attention q/k/v patch.

    Args:
        path: file path passed to ``comfy.utils.load_torch_file``.
        strength: multiplier applied to the hypernetwork output before it is
            added to ``k`` and ``v``.

    Returns:
        A ``hypernetwork_patch`` instance, or ``None`` (after logging an error)
        when the file uses an unsupported activation function.
    """
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    if activation_func is None:
        # A stored None means "no activation", exactly like 'linear'.
        activation_func = 'linear'
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    valid_activation = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
        "softsign": torch.nn.Softsign,
        "mish": torch.nn.Mish,
    }

    if activation_func not in valid_activation:
        logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
        return None

    # 'linear' inserts no activation module at all (not even an Identity).
    activation_cls = None if activation_func == "linear" else valid_activation[activation_func]

    out = {}
    for d in sd:
        # Only keys that parse as integers hold weights; the rest are settings.
        try:
            dim = int(d)
        except (TypeError, ValueError):
            continue

        # Index with the original key so string keys like "768" also work.
        entry = sd[d]
        output = [
            _build_hypernetwork_mlp(entry[index], activation_cls, activate_output,
                                    is_layer_norm, use_dropout, last_layer_dropout)
            for index in (0, 1)
        ]
        modules = torch.nn.ModuleList(output)
        # Modules start in training mode, where Dropout randomly zeroes values.
        # The patch is only used for inference, so switch to eval mode (making
        # Dropout a no-op) and freeze the parameters.
        modules.eval()
        modules.requires_grad_(False)
        out[dim] = modules

    class hypernetwork_patch:
        """Attention patch adding the hypernetwork's output to k and v."""

        def __init__(self, hypernet, strength):
            # hypernet: {context_dim: ModuleList([k_module, v_module])}
            self.hypernet = hypernet
            self.strength = strength

        def __call__(self, q, k, v, extra_options):
            """Return (q, k, v), with k and v adjusted when a matching dim exists."""
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            """Move every hypernetwork module to ``device``; returns self."""
            for d in self.hypernet:
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)
|
||
|
|
||
|
class HypernetworkLoader:
    """Node that clones a model and attaches a hypernetwork to attn1/attn2."""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a model, a hypernetwork file and a strength."""
        available_files = folder_paths.get_filename_list("hypernetworks")
        strength_spec = {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}
        required_inputs = {
            "model": ("MODEL",),
            "hypernetwork_name": (available_files, ),
            "strength": ("FLOAT", strength_spec),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_hypernetwork"
    CATEGORY = "loaders"

    def load_hypernetwork(self, model, hypernetwork_name, strength):
        """Return a 1-tuple holding a clone of ``model`` with the hypernetwork applied.

        If the file cannot be turned into a patch, the clone is returned unpatched.
        """
        full_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
        patched_model = model.clone()
        attn_patch = load_hypernetwork_patch(full_path, strength)
        if attn_patch is None:
            return (patched_model,)
        # The same patch object serves both self-attention and cross-attention.
        patched_model.set_model_attn1_patch(attn_patch)
        patched_model.set_model_attn2_patch(attn_patch)
        return (patched_model,)
|
||
|
|
||
|
# Node registry for this module: node type name -> implementing class.
NODE_CLASS_MAPPINGS = dict(HypernetworkLoader=HypernetworkLoader)
|