diff --git a/datasets/__init__.py b/datasets/__init__.py index 31ae9f3..e3f3cbd 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -1,8 +1,8 @@ train_dict = { - "corps_MMD_01": [ - "Uncle", - "Uncle", - ], + # "corps_MMD_01": [ + # "Uncle", + # "Uncle", + # ], "corps_BerkeleyMHAD_01": [ #"BerkeleyMHAD_skl_s01", #"BerkeleyMHAD_skl_s02", diff --git a/demo b/demo index 06f9531..42f01f4 100644 --- a/demo +++ b/demo @@ -23,7 +23,7 @@ def get_height(file): def example( - src_name, src_bvh, dst_name, dst_bvh, output_path, convert_type + src_name, src_bvh, ref_name, ref_bvh, output_path, convert_type ): if not os.path.exists(output_path): os.makedirs(output_path) @@ -31,14 +31,14 @@ def example( input_file = f"./datasets/Motions/{src_name}/{src_bvh}" copy_ref_file(input_file, pjoin(output_path, "input.bvh")) - ref_file = f'./datasets/Motions/{dst_name}/{dst_bvh}' + ref_file = f'./datasets/Motions/{ref_name}/{ref_bvh}' copy_ref_file(ref_file, pjoin(output_path, 'reference.bvh')) height = get_height(input_file) input_file = f"./datasets/Motions/{src_name}/{src_bvh}" - ref_file = f"./datasets/Motions/{dst_name}/{dst_bvh}" + ref_file = f"./datasets/Motions/{ref_name}/{ref_bvh}" joined_output_path = pjoin(output_path, "result.bvh") - cmd = f"python eval_single_pair.py --input_bvh=\"{input_file}\" --target_bvh=\"{ref_file}\" --output_filename=\"{joined_output_path}\" --convert_type=\"{convert_type}\"" + cmd = f"python eval_single_pair.py --input_bvh=\"{input_file}\" --ref_bvh=\"{ref_file}\" --output_filename=\"{joined_output_path}\" --convert_type=\"{convert_type}\"" err = os.system(cmd) if err: @@ -54,24 +54,24 @@ def example( if __name__ == "__main__": src_name = "Kaya" src_bvh = "Golf Putt Victory.bvh" - tgt_name = "BerkeleyMHAD_skl_s04" - tgt_bvh = "skl_s04_a01_r01.bvh" + ref_name = "BerkeleyMHAD_skl_s04" + ref_bvh = "skl_s04_a01_r01.bvh" example( src_name=src_name, src_bvh=src_bvh, - ref_name=dst_name, - ref_bvh=dst_bvh, - output_path=f"./examples/{dst_name}_{src_bvh}/", + ref_name=ref_name, + ref_bvh=ref_bvh, + output_path=f"./examples/{ref_name}_{src_bvh}/", convert_type="cross", ) - src_name, dst_name = dst_name, src_name - src_bvh, dst_bvh = dst_bvh, src_bvh + src_name, ref_name = ref_name, src_name + src_bvh, ref_bvh = ref_bvh, src_bvh example( src_name=src_name, src_bvh=src_bvh, - ref_name=dst_name, - ref_bvh=dst_bvh, - output_path=f"./examples/{dst_name}_{src_bvh}/", + ref_name=ref_name, + ref_bvh=ref_bvh, + output_path=f"./examples/{ref_name}_{src_bvh}/", convert_type="cross", ) diff --git a/eval_single_pair.py b/eval_single_pair.py index 850968b..34f0d4d 100644 --- a/eval_single_pair.py +++ b/eval_single_pair.py @@ -13,7 +13,7 @@ def main(): parser = option_parser.get_parser() parser.add_argument("--input_bvh", type=str, required=True) - parser.add_argument("--target_bvh", type=str, required=False) + parser.add_argument("--ref_bvh", type=str, required=False) parser.add_argument("--output_filename", type=str, required=True) parser.add_argument("--cpu", type=bool, required=False) parser.add_argument("--convert_type", type=str, required=True) @@ -21,12 +21,12 @@ def main(): args = parser.parse_args() input_bvh = args.input_bvh - target_bvh = args.target_bvh + ref_bvh = args.ref_bvh cpu = args.cpu _convert_type = args.convert_type source_character = input_bvh.split("/")[-2] - target_character = target_bvh.split("/")[-2] + target_character = ref_bvh.split("/")[-2] print(f'Source character {source_character}') print(f'Target character {target_character}') character_names = [] @@ -41,14 +41,14 @@ def main(): if topo_index % 2 == 0: character_names.append([target_character]) - file_id.append([target_bvh]) + file_id.append([ref_bvh]) character_names.append([source_character]) file_id.append([input_bvh]) else: character_names.append([source_character]) file_id.append([input_bvh]) character_names.append([target_character]) - file_id.append([target_bvh]) + file_id.append([ref_bvh]) character_names = [character_names] print(f'Character names {character_names}') @@ -73,7 +73,7 @@ def main(): print(f'Characters {character_names}') dataset = create_dataset(args, character_names) model = create_model(args, character_names, dataset, topologies) - model.load(epoch=1900, topology=topo_index) + model.load(epoch=50, topology=topo_index) input_motion = [] if not os.path.exists(input_bvh):