70 changes: 37 additions & 33 deletions extract_feature_print.py
@@ -1,4 +1,5 @@
import os, sys, traceback
import tqdm

# device=sys.argv[1]
n_part = int(sys.argv[2])
@@ -87,37 +88,40 @@ def readwave(wav_path, normalize=False):
printt("no-feature-todo")
else:
printt("all-feature-%s" % len(todo))
for idx, file in enumerate(todo):
try:
if file.endswith(".wav"):
wav_path = "%s/%s" % (wavPath, file)
out_path = "%s/%s" % (outPath, file.replace("wav", "npy"))

if os.path.exists(out_path):
continue

feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
inputs = {
"source": feats.half().to(device)
if device not in ["mps", "cpu"]
else feats.to(device),
"padding_mask": padding_mask.to(device),
"output_layer": 9 if version == "v1" else 12, # layer 9
}
with torch.no_grad():
logits = model.extract_features(**inputs)
feats = (
model.final_proj(logits[0]) if version == "v1" else logits[0]
)

feats = feats.squeeze(0).float().cpu().numpy()
if np.isnan(feats).sum() == 0:
np.save(out_path, feats, allow_pickle=False)
else:
printt("%s-contains nan" % file)
if idx % n == 0:
printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
except:
printt(traceback.format_exc())
with tqdm.tqdm(total=len(todo)) as pbar:
for idx, file in enumerate(todo):
try:
if file.endswith(".wav"):
wav_path = "%s/%s" % (wavPath, file)
out_path = "%s/%s" % (outPath, file.replace("wav", "npy"))

if os.path.exists(out_path):
continue

feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
inputs = {
"source": feats.half().to(device)
if device not in ["mps", "cpu"]
else feats.to(device),
"padding_mask": padding_mask.to(device),
"output_layer": 9 if version == "v1" else 12, # layer 9
}
with torch.no_grad():
logits = model.extract_features(**inputs)
feats = (
model.final_proj(logits[0]) if version == "v1" else logits[0]
)

feats = feats.squeeze(0).float().cpu().numpy()
if np.isnan(feats).sum() == 0:
np.save(out_path, feats, allow_pickle=False)
else:
printt("%s-contains nan" % file)
# if idx % n == 0:
# printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
pbar.set_description("file %s, shape %s" % (file, feats.shape))
except:
printt(traceback.format_exc())
pbar.update(1)
printt("all-feature-done")
96 changes: 35 additions & 61 deletions train/utils.py
Original file line number Diff line number Diff line change
@@ -31,10 +31,7 @@ def go(model, bkey):
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
"shape-%s-mismatch|need-%s|get-%s"
% (k, state_dict[k].shape, saved_state_dict[k].shape)
) #
print("shape-%s-mismatch|need-%s|get-%s" % (k, state_dict[k].shape, saved_state_dict[k].shape)) #
raise KeyError
except:
# logger.info(traceback.format_exc())
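Both loaders here use the same tolerant-loading pattern: each key is copied from the saved state dict, a shape mismatch is promoted to KeyError, and the except branch keeps the model's freshly initialized tensor. A minimal sketch of the idea, assuming an arbitrary torch model and a possibly incompatible saved_state_dict:

import torch

def load_partial(model: torch.nn.Module, saved_state_dict: dict) -> None:
    state_dict = model.state_dict()
    new_state_dict = {}
    for k in state_dict:
        try:
            new_state_dict[k] = saved_state_dict[k]
            if saved_state_dict[k].shape != state_dict[k].shape:
                raise KeyError  # treat a shape mismatch like a missing key
        except KeyError:
            new_state_dict[k] = state_dict[k]  # keep the model's own initialization
    model.load_state_dict(new_state_dict)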
@@ -52,9 +49,7 @@ def go(model, bkey):

iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
if (
optimizer is not None and load_opt == 1
): ### if this cannot be loaded (e.g. the saved state is empty), reinitialize; that may also affect the LR schedule update, so catch it at the outermost level of the train script
if optimizer is not None and load_opt == 1: ### if this cannot be loaded (e.g. the saved state is empty), reinitialize; that may also affect the LR schedule update, so catch it at the outermost level of the train script
# try:
optimizer.load_state_dict(checkpoint_dict["optimizer"])
# except:
@@ -106,10 +101,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
"shape-%s-mismatch|need-%s|get-%s"
% (k, state_dict[k].shape, saved_state_dict[k].shape)
) #
print("shape-%s-mismatch|need-%s|get-%s" % (k, state_dict[k].shape, saved_state_dict[k].shape)) #
raise KeyError
except:
# logger.info(traceback.format_exc())
@@ -123,9 +115,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):

iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
if (
optimizer is not None and load_opt == 1
): ### if this cannot be loaded (e.g. the saved state is empty), reinitialize; that may also affect the LR schedule update, so catch it at the outermost level of the train script
if optimizer is not None and load_opt == 1: ### if this cannot be loaded (e.g. the saved state is empty), reinitialize; that may also affect the LR schedule update, so catch it at the outermost level of the train script
# try:
optimizer.load_state_dict(checkpoint_dict["optimizer"])
# except:
Expand All @@ -134,33 +124,39 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
"Saving model and optimizer state at epoch {} to {}".format(
iteration, checkpoint_path
)
)
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path, checkpoint_type, delete_old=False):
# logger.info(
# "Saving model and optimizer state at epoch {} to {}".format(
# iteration, checkpoint_path
# )
# )
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
if delete_old:
latest_checkpoint = latest_checkpoint_path(checkpoint_path, regex=("G_*.pth" if checkpoint_type.startswith("G") else "D_*.pth"))

torch.save(
{
"model": state_dict,
"iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
os.path.join(checkpoint_path, checkpoint_type),
)
# delete after saving new checkpoint to avoid loss if save fails
if delete_old and latest_checkpoint is not None:
os.remove(latest_checkpoint)
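The ordering is deliberate: the previous checkpoint is located first, the new one is written, and only then is the old file deleted, so a failed save cannot leave the directory without any checkpoint. A minimal sketch of the same rotation, with hypothetical names and the digit-sort key that latest_checkpoint_path uses:

import glob
import os

import torch

def rotate_save(obj, ckpt_dir: str, new_name: str, pattern: str = "G_*.pth") -> None:
    candidates = glob.glob(os.path.join(ckpt_dir, pattern))
    candidates.sort(key=lambda f: int("".join(filter(str.isdigit, f))))  # same key as latest_checkpoint_path
    old = candidates[-1] if candidates else None
    torch.save(obj, os.path.join(ckpt_dir, new_name))  # save the new file first
    if old is not None and old != os.path.join(ckpt_dir, new_name):
        os.remove(old)  # delete the old file only after the save succeeded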


def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
"Saving model and optimizer state at epoch {} to {}".format(
iteration, checkpoint_path
)
)
# logger.info(
# "Saving model and optimizer state at epoch {} to {}".format(
# iteration, checkpoint_path
# )

if hasattr(combd, "module"):
state_dict_combd = combd.module.state_dict()
else:
@@ -203,8 +199,10 @@ def summarize(
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
f_list = glob.glob(os.path.join(dir_path, regex))
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
if len(f_list) == 0:
return None
x = f_list[-1]
print(x)
# print(x)
return x
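Because the function can now return None when the directory holds no matching file, callers should branch on the result. A hedged usage sketch, assuming model and optimizer are already constructed (the experiment directory is hypothetical):

ckpt = latest_checkpoint_path("logs/my_experiment", regex="G_*.pth")
if ckpt is None:
    print("no generator checkpoint yet; starting fresh")
else:
    model, optimizer, lr, iteration = load_checkpoint(ckpt, model, optimizer)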


@@ -247,9 +245,7 @@ def plot_alignment_to_numpy(alignment, info=None):
import numpy as np

fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
im = ax.imshow(alignment.transpose(), aspect="auto", origin="lower", interpolation="none")
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
@@ -302,35 +298,21 @@ def get_hparams(init=True):
required=True,
help="checkpoint save frequency (epoch)",
)
parser.add_argument(
"-te", "--total_epoch", type=int, required=True, help="total_epoch"
)
parser.add_argument(
"-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path"
)
parser.add_argument(
"-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path"
)
parser.add_argument("-te", "--total_epoch", type=int, required=True, help="total_epoch")
parser.add_argument("-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path")
parser.add_argument("-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path")
parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
parser.add_argument(
"-bs", "--batch_size", type=int, required=True, help="batch size"
)
parser.add_argument(
"-e", "--experiment_dir", type=str, required=True, help="experiment dir"
) # -m
parser.add_argument(
"-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
)
parser.add_argument("-bs", "--batch_size", type=int, required=True, help="batch size")
parser.add_argument("-e", "--experiment_dir", type=str, required=True, help="experiment dir") # -m
parser.add_argument("-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k")
parser.add_argument(
"-sw",
"--save_every_weights",
type=str,
default="0",
help="save the extracted model in weights directory when saving checkpoints",
)
parser.add_argument(
"-v", "--version", type=str, required=True, help="model version"
)
parser.add_argument("-v", "--version", type=str, required=True, help="model version")
parser.add_argument(
"-f0",
"--if_f0",
@@ -417,11 +399,7 @@ def get_hparams_from_file(config_path):
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
)
)
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(source_dir))
return

cur_hash = subprocess.getoutput("git rev-parse HEAD")
@@ -430,11 +408,7 @@ def check_git_hash(model_dir):
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]
)
)
logger.warn("git hash values are different. {}(saved) != {}(current)".format(saved_hash[:8], cur_hash[:8]))
else:
open(path, "w").write(cur_hash)
