
RuntimeError: mat1 and mat2 shapes cannot be multiplied (69x64 and 32x4096) in inference_server.py #33

@ahmetkca

Description

Traceback (most recent call last):
  File "/root/hertz-dev/inference_server.py", line 168, in <module>
    audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/inference_server.py", line 59, in __init__
    self.initialize_state(prompt_path)
  File "/root/hertz-dev/inference_server.py", line 80, in initialize_state
    self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/model.py", line 323, in next_audio_from_audio
    next_latents = self.next_latent(latents_in, temps)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/model.py", line 341, in next_latent
    logits = self.forward(model_input)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/model.py", line 310, in forward
    x = self.input(data)
        ^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (69x64 and 32x4096)
(ai) root@87a45c3f6c90:~/hertz-dev# nvidia-smi
Wed Dec  4 01:55:13 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0 Off |                  Off |
|  0%   26C    P8             21W /  450W |       2MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
(ai) root@87a45c3f6c90:~/hertz-dev# pip freeze
annotated-types==0.7.0
anyio==4.6.2.post1
asttokens==3.0.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.4.0
click==8.1.7
contourpy==1.3.1
cycler==0.12.1
decorator==5.1.1
einops==0.8.0
executing==2.1.0
fastapi==0.115.4
filelock==3.16.1
fonttools==4.55.1
fsspec==2024.10.0
h11==0.14.0
hf_transfer==0.1.8
huggingface-hub==0.26.2
idna==3.10
IProgress==0.4
ipython==8.18.1
jedi==0.19.2
Jinja2==3.1.4
kiwisolver==1.4.7
MarkupSafe==3.0.2
matplotlib==3.9.2
matplotlib-inline==0.1.7
mpmath==1.3.0
networkx==3.4.2
numpy==1.26.3
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
packaging==24.2
parso==0.8.4
pexpect==4.9.0
pillow==11.0.0
prompt_toolkit==3.0.48
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.3
pydantic_core==2.27.1
Pygments==2.18.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
PyYAML==6.0.2
requests==2.32.3
six==1.16.0
sniffio==1.3.1
sounddevice==0.5.1
soundfile==0.12.1
stack-data==0.6.3
starlette==0.41.3
sympy==1.13.1
torch==2.5.1
torchaudio==2.5.1
tqdm==4.66.6
traitlets==5.14.3
triton==3.1.0
typing_extensions==4.12.2
urllib3==2.2.3
uvicorn==0.32.0
wcwidth==0.2.13
websockets==13.1
(ai) root@87a45c3f6c90:~/hertz-dev# conda info

     active environment : ai
    active env location : /root/miniconda3/envs/ai
            shell level : 2
       user config file : /root/.condarc
 populated config files : /root/miniconda3/.condarc
          conda version : 24.9.2
    conda-build version : not installed
         python version : 3.12.7.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=zen3
                          __conda=24.9.2=0
                          __cuda=12.4=0
                          __glibc=2.35=0
                          __linux=6.5.0=0
                          __unix=0=0
       base environment : /root/miniconda3  (writable)
      conda av data dir : /root/miniconda3/etc/conda
  conda av metadata url : None
           channel URLs : https://repo.anaconda.com/pkgs/main/linux-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /root/miniconda3/pkgs
                          /root/.conda/pkgs
       envs directories : /root/miniconda3/envs
                          /root/.conda/envs
               platform : linux-64
             user-agent : conda/24.9.2 requests/2.32.3 CPython/3.12.7 Linux/6.5.0-45-generic ubuntu/22.04.5 glibc/2.35 solver/libmamba conda-libmamba-solver/24.9.0 libmambapy/1.5.8 aau/0.4.4 c/. s/. e/.
                UID:GID : 0:0
             netrc file : None
           offline mode : False
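The last frame of the traceback is `F.linear`, which computes `input @ weight.T`. In `(69x64 and 32x4096)`, the input's last dimension (64) therefore has to equal the first dimension of `weight.T` (32). This suggests that the model's `self.input` layer is `nn.Linear(32, 4096)`, while the latents that reach it from `next_audio_from_audio` carry 64 features. The sketch below decodes such a message so the two feature counts can be read off directly. It assumes the message format shown above. The helper names (`parse_linear_mismatch`, `LinearMismatch`) are hypothetical and are not part of hertz-dev or PyTorch.

```python
import re
from typing import NamedTuple


class LinearMismatch(NamedTuple):
    rows: int           # input rows (leading dims folded together), mat1 dim 0
    got_features: int   # last dim of the input tensor, mat1 dim 1
    in_features: int    # what the nn.Linear expects, mat2 dim 0
    out_features: int   # the layer's output width, mat2 dim 1


# Assumed message format, copied from the traceback in this issue.
_PATTERN = re.compile(
    r"mat1 and mat2 shapes cannot be multiplied \((\d+)x(\d+) and (\d+)x(\d+)\)"
)


def parse_linear_mismatch(message: str) -> LinearMismatch:
    """Decode the shapes from a PyTorch F.linear shape-mismatch message.

    F.linear(input, weight) computes input @ weight.T, so mat2 is weight.T
    with shape (in_features, out_features).
    """
    match = _PATTERN.search(message)
    if match is None:
        raise ValueError("not a mat1/mat2 shape-mismatch message")
    rows, got, in_f, out_f = (int(g) for g in match.groups())
    return LinearMismatch(rows, got, in_f, out_f)


msg = "RuntimeError: mat1 and mat2 shapes cannot be multiplied (69x64 and 32x4096)"
info = parse_linear_mismatch(msg)
print(f"layer is nn.Linear({info.in_features}, {info.out_features}), "
      f"but the input's last dimension is {info.got_features}")
```

Applied to the error in this issue, the output shows that the input projection expects 32 input features but receives 64. This points at a mismatch between the model weights/config in use and the latents produced from the prompt audio, rather than at the GPU or driver setup.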

