Description
When I use code like the following (I want to LoRA-train the UNet as well):
train_unet_lora = add_lora_to(
    unet,
    target_module=unet_replace,
    search_class=[torch.nn.Linear],
    r=args.lora_rank,
    lora_bias=args.lora_bias,
)
it raises the following error:
  File "train_lora.py", line 624, in <module>
    main(args)
  File "train_lora.py", line 522, in main
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/accelerate/utils/operations.py", line 817, in forward
    return model_forward(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/accelerate/utils/operations.py", line 805, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 1121, in forward
    sample, res_samples = downsample_block(
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 1198, in forward
    hidden_states = resnet(hidden_states, temb, scale=lora_scale)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/resnet.py", line 373, in forward
    self.time_emb_proj(temb, scale)[:, :, None, None]
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
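The last frames hint at the cause: diffusers' resnet block calls `self.time_emb_proj(temb, scale)`, passing the LoRA scale as a second positional argument (diffusers' own `LoRACompatibleLinear.forward` accepts it as `scale`). Presumably `add_lora_to` replaced that layer, because it matches `search_class=[torch.nn.Linear]`, with a module whose `forward` takes only the input, so the extra argument triggers the `TypeError`. The plain-Python sketch below reproduces the mismatch without torch or diffusers. `Module`, `PlainLoraLinear`, `CompatLoraLinear`, and `call_like_resnet` are hypothetical stand-ins, and scalar weights replace tensors for brevity. Giving the injected layer's `forward` an optional `scale` parameter is one possible workaround; excluding such layers from replacement may be another. Neither has been confirmed against the library itself.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module: calling the object dispatches to forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class PlainLoraLinear(Module):
    """Hypothetical LoRA layer whose forward() accepts only the input (scalars instead of tensors)."""

    def __init__(self, weight, down, up):
        self.weight = weight  # frozen base weight
        self.down = down      # trainable low-rank "down" factor
        self.up = up          # trainable low-rank "up" factor

    def forward(self, x):
        return self.weight * x + self.up * self.down * x


class CompatLoraLinear(PlainLoraLinear):
    """Same layer, but forward() also accepts the extra `scale` argument diffusers passes."""

    def forward(self, x, scale=1.0):
        # Scale only the LoRA branch, similar to diffusers' LoRACompatibleLinear.
        return self.weight * x + scale * self.up * self.down * x


def call_like_resnet(layer, temb, scale):
    """Mimics the call `self.time_emb_proj(temb, scale)` from the traceback."""
    return layer(temb, scale)


if __name__ == "__main__":
    try:
        call_like_resnet(PlainLoraLinear(2.0, 0.5, 3.0), 1.0, 1.0)
    except TypeError as err:
        print(err)  # ...forward() takes 2 positional arguments but 3 were given

    print(call_like_resnet(CompatLoraLinear(2.0, 0.5, 3.0), 1.0, 1.0))  # 3.5
```

With `scale=1.0` the compatible layer returns the base output plus the full LoRA branch; passing `scale=0.0` would disable the LoRA contribution and return only the base output.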