603 lines
23 KiB
Python
603 lines
23 KiB
Python
|
import os
|
||
|
|
||
|
from transformers import CLIPTokenizer
|
||
|
import comfy.ops
|
||
|
import torch
|
||
|
import traceback
|
||
|
import zipfile
|
||
|
from . import model_management
|
||
|
import comfy.clip_model
|
||
|
import json
|
||
|
import logging
|
||
|
import numbers
|
||
|
|
||
|
def gen_empty_tokens(special_tokens, length):
    """Build the token list for an empty prompt.

    The "start" and "end" tokens (each only if present and not None in
    *special_tokens*) come first, followed by ``special_tokens.get("pad")``
    repeated until the list is *length* long. If the start/end tokens alone
    already reach or exceed *length*, no padding is added.
    """
    framing = [
        tok
        for tok in (special_tokens.get("start", None), special_tokens.get("end", None))
        if tok is not None
    ]
    padding = [special_tokens.get("pad")] * (length - len(framing))
    return framing + padding
|
||
|
|
||
|
class ClipTokenWeightEncoder:
    """Mixin that encodes batches of (token, weight) pairs and applies the weights.

    The host class must provide ``self.special_tokens`` (used to build an empty
    prompt) and ``self.encode(list_of_token_lists)``, which returns a sequence
    whose first two items are the per-token output and the pooled output
    (possibly None), and optionally a third item: a dict of extras.
    """

    def encode_token_weights(self, token_weight_pairs):
        """Encode weighted token sections and concatenate them along the token axis.

        Args:
            token_weight_pairs: list of sections, each a list of items whose
                ``[0]`` is the token and ``[1]`` is its weight.

        Returns:
            ``(tokens_out, first_pooled)`` or ``(tokens_out, first_pooled, extra)``.
            ``tokens_out`` is the concatenation (``dim=-2``) of every section's
            output; ``first_pooled`` is the first pooled row (or None);
            ``extra`` is the dict from ``encode`` with any "attention_mask"
            trimmed to the real sections and flattened into one row.

        When any weight differs from 1.0, an empty prompt is encoded alongside
        the sections and each weighted token output is scaled relative to the
        empty prompt's output at the same position:
        ``z = (z - z_empty) * weight + z_empty``. If ``encode`` returns a
        tensor, the per-section slices are views, so this edits it in place.
        """
        batches = []
        longest = 0
        weighted = False
        for section in token_weight_pairs:
            ids = [pair[0] for pair in section]
            longest = max(len(ids), longest)
            if not weighted:
                weighted = any(not (pair[1] == 1.0) for pair in section)
            batches.append(ids)

        section_count = len(batches)
        # The empty prompt is the reference point for weighting; it is also
        # the only thing to encode when there are no sections at all.
        if weighted or section_count == 0:
            batches.append(gen_empty_tokens(self.special_tokens, longest))

        encoded = self.encode(batches)
        hidden, pooled = encoded[:2]

        first_pooled = None if pooled is None else pooled[0:1].to(model_management.intermediate_device())

        baseline = hidden[-1] if weighted else None
        chunks = []
        for idx in range(section_count):
            chunk = hidden[idx:idx + 1]
            if weighted:
                for row in range(len(chunk)):
                    for col in range(len(chunk[row])):
                        w = token_weight_pairs[idx][col][1]
                        if w != 1.0:
                            chunk[row][col] = (chunk[row][col] - baseline[col]) * w + baseline[col]
            chunks.append(chunk)

        if chunks:
            result = (torch.cat(chunks, dim=-2).to(model_management.intermediate_device()), first_pooled)
        else:
            # No real sections: fall back to the empty-prompt encoding.
            result = (hidden[-1:].to(model_management.intermediate_device()), first_pooled)

        if len(encoded) > 2:
            extras = {}
            for key in encoded[2]:
                value = encoded[2][key]
                if key == "attention_mask":
                    # Drop the empty-prompt row and lay the sections' masks end to end.
                    value = value[:section_count].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
                extras[key] = value
            result = result + (extras,)

        return result
|
||
|
|
||
|
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """CLIP text encoder wrapper (defaults follow clip-vit-base-patch32 / SD1.x).

    Wraps a transformer built by ``model_class`` from a JSON config and adds:
      * textual-inversion support: embedding vectors found in a token list are
        spliced into a temporary input-embedding table for one ``forward`` call;
      * output-layer selection: ``self.layer == "last"`` uses ``outputs[0]``,
        any other value uses ``outputs[1]`` (the intermediate output requested
        via ``self.layer_idx``);
      * optional attention masks, zeroing of masked positions, and returning
        the mask to the caller.
    """

    LAYERS = [
        "last",
        "pooled",
        "hidden",
    ]

    def __init__(self, device="cpu", max_length=77,
                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                 special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
                 return_projected_pooled=True, return_attention_masks=False, model_options=None):  # clip-vit-base-patch32
        """Build the transformer and configure output selection.

        Args:
            device, dtype: forwarded to ``model_class``.
            max_length: stored as ``self.max_length``.
            freeze: when true, the transformer is put in eval mode and every
                parameter of this module gets ``requires_grad = False``.
            layer: one of ``LAYERS``; ``"hidden"`` requires ``layer_idx``.
            layer_idx: intermediate layer index, must satisfy
                ``abs(layer_idx) < num_layers`` when ``layer == "hidden"``.
            textmodel_json_config: path to the transformer JSON config;
                defaults to ``sd1_clip_config.json`` next to this file.
            special_tokens: dict of "start"/"end"/"pad" ids; defaults to
                ``{"start": 49406, "end": 49407, "pad": 49407}``.
            model_options: may contain "custom_operations" (operations class)
                or "scaled_fp8" (a dtype enabling scaled fp8 ops).
        """
        super().__init__()
        assert layer in self.LAYERS

        # None sentinels: a dict default would be one object shared by every instance.
        if special_tokens is None:
            special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
        if model_options is None:
            model_options = {}

        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")

        with open(textmodel_json_config, encoding="utf-8") as f:
            config = json.load(f)

        operations = model_options.get("custom_operations", None)
        scaled_fp8 = None

        if operations is None:
            scaled_fp8 = model_options.get("scaled_fp8", None)
            if scaled_fp8 is not None:
                operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
            else:
                operations = comfy.ops.manual_cast

        self.operations = operations
        self.transformer = model_class(config, dtype, device, self.operations)
        if scaled_fp8 is not None:
            # Empty marker parameter recording the scaled-fp8 dtype.
            self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))

        self.num_layers = self.transformer.num_layers

        self.max_length = max_length
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens

        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.enable_attention_masks = enable_attention_masks
        self.zero_out_masked = zero_out_masked

        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        self.return_attention_masks = return_attention_masks

        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

        if freeze:
            # Done last so that parameters created above (logit_scale included) are frozen too.
            self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def set_clip_options(self, options):
        """Update output selection from an options dict.

        "layer" selects an intermediate layer index (``None`` or
        ``abs(idx) > num_layers`` falls back to the last layer);
        "projected_pooled" sets ``return_projected_pooled``. Missing keys keep
        the current values.
        """
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_options(self):
        """Restore the layer / layer_idx / projected-pooled settings captured at construction."""
        self.layer = self.options_default[0]
        self.layer_idx = self.options_default[1]
        self.return_projected_pooled = self.options_default[2]

    def set_up_textual_embeddings(self, tokens, current_embeds):
        """Replace embedding vectors in *tokens* with newly allocated token ids.

        Integer entries are kept as-is. Any other entry is treated as an
        embedding vector: if ``y.shape[0]`` equals the embedding width it gets
        the next id after the existing vocabulary and its vector is collected;
        otherwise it is dropped with a warning and the row is right-padded with
        the pad token so it keeps its length.

        If at least one vector was collected, a new Embedding table (existing
        rows, then the collected vectors, then one extra row) is installed via
        ``self.transformer.set_input_embeddings``. The caller must restore the
        original table afterwards.

        Token ids equal to -1 are remapped to the id just past the collected
        vectors. NOTE(review): when no vectors were collected that id equals
        the original vocabulary size, which is outside the original table.

        Returns:
            The token rows as lists of ints.
        """
        out_tokens = []
        next_new_token = token_dict_size = current_embeds.weight.shape[0]
        embedding_weights = []

        for x in tokens:
            tokens_temp = []
            for y in x:
                if isinstance(y, numbers.Integral):
                    tokens_temp += [int(y)]
                else:
                    if y.shape[0] == current_embeds.weight.shape[1]:
                        embedding_weights += [y]
                        tokens_temp += [next_new_token]
                        next_new_token += 1
                    else:
                        logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
            while len(tokens_temp) < len(x):
                tokens_temp += [self.special_tokens["pad"]]
            out_tokens += [tokens_temp]

        n = token_dict_size
        if len(embedding_weights) > 0:
            new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
            new_embedding.weight[:token_dict_size] = current_embeds.weight
            for x in embedding_weights:
                new_embedding.weight[n] = x
                n += 1
            self.transformer.set_input_embeddings(new_embedding)

        processed_tokens = []
        for x in out_tokens:
            processed_tokens += [list(map(lambda a: n if a == -1 else a, x))]  # The EOS token should always be the largest one

        return processed_tokens

    def _build_attention_mask(self, tokens):
        """Return a mask that is 1 up to and including each row's first end token, 0 after it.

        Rows without an end token are all ones. The result has the same shape,
        dtype and device as *tokens* (a 2-D tensor of token ids).
        """
        end_token = self.special_tokens.get("end", -1)
        if end_token is None:
            # No end token can match, so nothing is masked.
            return torch.ones_like(tokens)
        is_end = (tokens == end_token).long()
        # Count of end tokens strictly before each position; keep positions where it is still 0.
        ends_before = torch.cumsum(is_end, dim=1) - is_end
        return (ends_before == 0).to(tokens.dtype)

    def forward(self, tokens):
        """Encode a batch of token rows.

        Args:
            tokens: list of rows; entries are integer ids or embedding vectors
                (see ``set_up_textual_embeddings``).

        Returns:
            ``(z, pooled_output)``, or ``(z, pooled_output, extra)`` when
            ``return_attention_masks`` is set, with ``extra["attention_mask"]``.
            ``z`` is float32; ``pooled_output`` is float32 or None.
        """
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        try:
            tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
            tokens = torch.LongTensor(tokens).to(device)

            attention_mask = None
            if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
                attention_mask = self._build_attention_mask(tokens)

            attention_mask_model = attention_mask if self.enable_attention_masks else None

            outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
        finally:
            # Always undo the temporary textual-inversion table, even if encoding failed.
            self.transformer.set_input_embeddings(backup_embeds)

        if self.layer == "last":
            z = outputs[0].float()
        else:
            z = outputs[1].float()

        if self.zero_out_masked:
            z *= attention_mask.unsqueeze(-1).float()

        pooled_output = None
        if len(outputs) >= 3:
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        extra = {}
        if self.return_attention_masks:
            extra["attention_mask"] = attention_mask

        if len(extra) > 0:
            return z, pooled_output, extra

        return z, pooled_output

    def encode(self, tokens):
        """Alias for calling the module; used by ``encode_token_weights``."""
        return self(tokens)

    def load_sd(self, sd):
        """Load *sd* into the transformer non-strictly; returns the load_state_dict result."""
        return self.transformer.load_state_dict(sd, strict=False)
|
||
|
|
||
|
def parse_parentheses(string):
    """Split *string* into top-level parenthesised groups and the text between them.

    A group keeps its outer parentheses and any nested groups verbatim, e.g.
    ``"a (b (c)) d"`` -> ``["a ", "(b (c))", " d"]``. An unclosed group is
    returned as-is; stray ``)`` characters stay in the surrounding text.
    """
    segments = []
    buf = ""
    depth = 0
    for ch in string:
        if ch == "(" and depth == 0:
            # Opening a top-level group: flush any plain text collected so far.
            if buf:
                segments.append(buf)
            buf = "("
            depth += 1
        elif ch == "(":
            buf += ch
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                segments.append(buf + ")")
                buf = ""
            else:
                buf += ch
        else:
            buf += ch
    if buf:
        segments.append(buf)
    return segments
|
||
|
|
||
|
def token_weights(string, current_weight):
    """Parse prompt-weighting syntax into a flat list of ``(text, weight)`` pairs.

    Text outside parentheses keeps *current_weight*. Each top-level
    ``( ... )`` group is parsed recursively with weight ``current_weight * 1.1``,
    unless the text after its last colon (at index > 0) parses as a float, in
    which case that number replaces the weight and is removed from the text,
    e.g. ``"(cat:1.5)"`` -> ``[("cat", 1.5)]``.
    """
    out = []
    for segment in parse_parentheses(string):
        if len(segment) >= 2 and segment[0] == '(' and segment[-1] == ')':
            inner = segment[1:-1]
            weight = current_weight * 1.1
            colon = inner.rfind(":")
            if colon > 0:
                try:
                    weight = float(inner[colon + 1:])
                except ValueError:
                    # Not a number after the colon: keep the 1.1x weight and the full text.
                    pass
                else:
                    inner = inner[:colon]
            out += token_weights(inner, weight)
        else:
            out.append((segment, current_weight))
    return out
|
||
|
|
||
|
def escape_important(text):
    """Hide backslash-escaped parentheses from the weight parser.

    ``\\)`` becomes ``"\\0\\1"`` and ``\\(`` becomes ``"\\0\\2"``, so the
    parenthesis parser does not treat them as group delimiters.
    ``unescape_important`` maps the placeholders back to plain parentheses.
    """
    for escaped, placeholder in (("\\)", "\0\1"), ("\\(", "\0\2")):
        text = text.replace(escaped, placeholder)
    return text
|
||
|
|
||
|
def unescape_important(text):
    """Turn the placeholders produced by ``escape_important`` into literal parentheses.

    ``"\\0\\1"`` becomes ``)`` and ``"\\0\\2"`` becomes ``(``.
    """
    for placeholder, plain in (("\0\1", ")"), ("\0\2", "(")):
        text = text.replace(placeholder, plain)
    return text
|
||
|
|
||
|
def safe_load_embed_zip(embed_path):
    """Read raw float32 embedding data from a zip archive without unpickling anything.

    Used as a fallback when ``torch.load(weights_only=True)`` refuses a file.
    Archive entries whose name contains ``"data/"`` are examined in reverse
    ``namelist()`` order. The first entry holding at least 768 float32 values
    that split evenly into rows of 768 (SD1.x) or, failing that, 1024 (SD2.x)
    is returned as a ``(num_embeds, row_length)`` tensor. Entries that do not
    fit either shape are skipped.

    Args:
        embed_path: path to the zip archive.

    Returns:
        A new float32 tensor, or ``None`` if no entry fits.

    Raises:
        zipfile.BadZipFile: if *embed_path* is not a zip archive.
    """
    with zipfile.ZipFile(embed_path) as archive:
        names = [name for name in archive.namelist() if "data/" in name]
        for name in reversed(names):
            with archive.open(name) as member:
                data = member.read()
            if len(data) % 4 != 0:
                # Cannot be a float32 storage.
                continue
            number = len(data) // 4
            if number < 768:
                continue
            if number % 768 == 0:
                length_embed = 768  # sd1.x
            elif number % 1024 == 0:
                length_embed = 1024  # sd2.x
            else:
                # Not a whole number of embedding rows; reshaping would fail, try the next entry.
                continue
            embed = torch.frombuffer(data, dtype=torch.float)
            # clone() so the result owns its memory instead of aliasing the read buffer.
            return embed.reshape((number // length_embed, length_embed)).clone()
    return None
|
||
|
|
||
|
def expand_directory_list(directories):
    """Return *directories* plus every directory beneath them, without duplicates.

    The order is deterministic: each input directory comes first, followed by
    its subdirectories in a sorted top-down walk; later duplicates are
    dropped. Symlinked directories are followed. This matters to callers that
    search the list in order and stop at the first match.
    """
    dirs = {}  # dict as an insertion-ordered set
    for top in directories:
        dirs[top] = None
        for root, subdirs, _files in os.walk(top, followlinks=True):
            subdirs.sort()  # in-place sort makes os.walk descend in a stable order
            dirs[root] = None
    return list(dirs)
|
||
|
|
||
|
def bundled_embed(embed, prefix, suffix):  # bundled embedding in lora format
    """Concatenate (along dim 0) every tensor in *embed* whose key starts with *prefix* and ends with *suffix*.

    Args:
        embed: mapping of key -> tensor (e.g. a loaded state dict).
        prefix, suffix: key filters, matched with ``str.startswith`` / ``str.endswith``.

    Returns:
        The concatenated tensor, or ``None`` when no key matches.
    """
    out_list = [embed[k] for k in embed if k.startswith(prefix) and k.endswith(suffix)]
    if len(out_list) == 0:
        return None

    return torch.cat(out_list, dim=0)
|
||
|
|
||
|
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    """Find and load a textual-inversion embedding by name.

    The name is resolved inside each directory from
    ``expand_directory_list(embedding_directory)``. Names resolving outside
    the directory (e.g. ``"../x"``) are ignored. The name is tried as given,
    then with ``.safetensors``, ``.pt``, ``.bin`` appended. The first hit wins.

    Supported formats (checked in this order): a dict with
    ``"string_to_param"``, a list of dicts of tensors (tensors whose last dim
    is *embedding_size* are kept), a dict containing *embed_key*, a LoRA-style
    bundle (``bundle_emb.*``), or else the first value of the dict.

    Args:
        embedding_name: file name, with or without extension.
        embedding_directory: a directory path or a list of them.
        embedding_size: embedding width used to filter list-format files.
        embed_key: preferred key (e.g. "clip_l") in multi-encoder files.

    Returns:
        The embedding tensor, or ``None`` if nothing was found or loading failed
        (failures are logged as warnings).
    """
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]

    embedding_directory = expand_directory_list(embedding_directory)

    valid_file = None
    for embed_dir in embedding_directory:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            # Reject names that escape the directory (path traversal).
            if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
                continue
        except ValueError:
            # e.g. paths on different drives on Windows.
            continue
        if os.path.isfile(embed_path):
            valid_file = embed_path
        else:
            for ext in ('.safetensors', '.pt', '.bin'):
                candidate = embed_path + ext
                if os.path.isfile(candidate):
                    valid_file = candidate
                    break
        if valid_file is not None:
            break

    if valid_file is None:
        return None

    embed_path = valid_file

    embed = None
    embed_out = None

    try:
        if embed_path.lower().endswith(".safetensors"):
            import safetensors.torch
            embed = safetensors.torch.load_file(embed_path, device="cpu")
        elif 'weights_only' in torch.load.__code__.co_varnames:
            try:
                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
            except Exception:
                # weights_only refused the file: read raw tensor data from the zip instead.
                embed_out = safe_load_embed_zip(embed_path)
        else:
            embed = torch.load(embed_path, map_location="cpu")
    except Exception:
        logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
        return None

    if embed_out is not None:
        return embed_out

    if embed is None:
        # Either the zip fallback found no usable data or the file held nothing.
        logging.warning("error loading embedding, skipping loading: {}".format(embedding_name))
        return None

    if 'string_to_param' in embed:
        values = embed['string_to_param'].values()
        embed_out = next(iter(values))
    elif isinstance(embed, list):
        out_list = []
        for entry in embed:
            for k in entry:
                t = entry[k]
                if t.shape[-1] != embedding_size:
                    continue
                out_list.append(t.reshape(-1, t.shape[-1]))
        if len(out_list) == 0:
            logging.warning("embedding {} has no tensors of size {}, skipping loading".format(embedding_name, embedding_size))
            return None
        embed_out = torch.cat(out_list, dim=0)
    elif embed_key is not None and embed_key in embed:
        embed_out = embed[embed_key]
    else:
        embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
        if embed_out is None:
            embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
        if embed_out is None:
            # Last resort: first tensor in the file (None if the file is an empty dict).
            embed_out = next(iter(embed.values()), None)
    return embed_out
|
||
|
|
||
|
class SDTokenizer:
    """Tokenizer for CLIP-style text encoders.

    Understands prompt weighting (``"(word:1.2)"``, escaped parentheses) and
    ``embedding:<name>`` references to textual-inversion files, and splits
    the result into batches of at most ``max_length`` tokens.
    """

    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data=None):
        """Load the underlying tokenizer and derive its special tokens.

        Start/end token ids are read from tokenizing the empty string: with a
        start token they are ``empty[0]`` / ``empty[1]``, otherwise the end
        token is ``empty[0]``. The pad token is *pad_token* if given, else the
        end token when *pad_with_end*, else 0. *tokenizer_data* is accepted for
        interface compatibility and not used here.
        """
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
        self.max_length = max_length
        self.min_length = min_length

        empty = self.tokenizer('')["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]

        if pad_token is not None:
            self.pad_token = pad_token
        elif pad_with_end:
            self.pad_token = self.end_token
        else:
            self.pad_token = 0

        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length

        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.embedding_directory = embedding_directory
        # Words with at least this many tokens may be split across batches.
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key

    def _try_get_embedding(self, embedding_name:str):
        '''
        Takes a potential embedding name and tries to retrieve it.
        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.

        If the name does not load as given, surrounding commas are stripped and
        the load is retried. The leftover string is then the run of trailing
        commas (e.g. "name," -> ","), so the caller can still tokenize it;
        leading commas are discarded.
        '''
        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None:
            stripped = embedding_name.strip(',')
            if len(stripped) < len(embedding_name):
                embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
                # Only trailing commas are leftover text; slicing by len(stripped) is wrong
                # when leading commas were also stripped.
                trailing = embedding_name[len(embedding_name.rstrip(',')):]
                return (embed, trailing)
        return (embed, "")

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed CLIP tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of CLIP
        (every batch is padded to max_length when pad_to_max_length is set; the
        last batch is additionally padded to min_length when given).
        With return_word_ids=False the elements are (token, weight) pairs.
        '''
        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # if we find an embedding, deal with the embedding
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
                    # if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                # parse word (drop the tokenizer's own start/end tokens)
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])

        # reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    # break word in two and add end token
                    if is_large:
                        batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    # add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
                    # start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        # fill last batch
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]

        return batched_tokens

    def untokenize(self, token_weight_pair):
        """Pair each ``(token, weight, ...)`` item with the vocabulary string of its token id."""
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))

    def state_dict(self):
        """The tokenizer has no state to save."""
        return {}
|
||
|
|
||
|
class SD1Tokenizer:
    """Single-encoder tokenizer wrapper that stores its tokenizer as attribute ``clip_<clip_name>``.

    Output of ``tokenize_with_weights`` is a dict keyed by ``clip_name``.
    """

    def __init__(self, embedding_directory=None, tokenizer_data=None, clip_name="l", tokenizer=SDTokenizer):
        """Create the wrapped tokenizer.

        ``tokenizer_data["clip_<clip_name>_tokenizer_class"]``, if present,
        overrides *tokenizer*. *tokenizer_data* is forwarded to the tokenizer.
        """
        # None sentinel: a dict default would be shared across calls and handed to arbitrary tokenizers.
        if tokenizer_data is None:
            tokenizer_data = {}
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)
        tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        """Tokenize *text* with the wrapped tokenizer; returns ``{clip_name: batches}``."""
        out = {}
        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
        return out

    def untokenize(self, token_weight_pair):
        """Delegate to the wrapped tokenizer's ``untokenize``."""
        return getattr(self, self.clip).untokenize(token_weight_pair)

    def state_dict(self):
        """The wrapper has no state to save."""
        return {}
|
||
|
|
||
|
class SD1CheckpointClipModel(SDClipModel):
    """``SDClipModel`` configured with ``return_projected_pooled=False`` for SD1.x checkpoints."""

    def __init__(self, device="cpu", dtype=None, model_options=None):
        # None sentinel instead of a shared mutable dict default.
        if model_options is None:
            model_options = {}
        super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||
|
|
||
|
class SD1ClipModel(torch.nn.Module):
    """Container for one named text-encoder submodule.

    The encoder is stored as attribute ``clip_<clip_name>`` (or ``<name>``
    when *name* is given) and all calls are delegated to it.
    """

    def __init__(self, device="cpu", dtype=None, model_options=None, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
        """Build the wrapped encoder.

        ``model_options["<attr>_class"]``, if present, overrides *clip_model*.
        *device*, *dtype*, *model_options* and extra *kwargs* are forwarded to it.
        ``self.dtypes`` records *dtype* when one is given.
        """
        super().__init__()
        # None sentinel instead of a shared mutable dict default.
        if model_options is None:
            model_options = {}

        if name is not None:
            self.clip_name = name
            self.clip = "{}".format(self.clip_name)
        else:
            self.clip_name = clip_name
            self.clip = "clip_{}".format(self.clip_name)

        clip_model = model_options.get("{}_class".format(self.clip), clip_model)
        setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))

        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def set_clip_options(self, options):
        """Delegate to the wrapped encoder's ``set_clip_options``."""
        getattr(self, self.clip).set_clip_options(options)

    def reset_clip_options(self):
        """Delegate to the wrapped encoder's ``reset_clip_options``."""
        getattr(self, self.clip).reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode ``token_weight_pairs[clip_name]`` with the wrapped encoder."""
        token_weight_pairs = token_weight_pairs[self.clip_name]
        out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
        return out

    def load_sd(self, sd):
        """Load a state dict into the wrapped encoder."""
        return getattr(self, self.clip).load_sd(sd)
|