
Bug in the Hugging Face example code #179

@LiuZhihhxx


I ran the Hugging Face demo code for AutoformerForPrediction:

from huggingface_hub import hf_hub_download
import torch
from transformers import AutoformerForPrediction

file = hf_hub_download(
    repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
)
batch = torch.load(file)

model = AutoformerForPrediction.from_pretrained("huggingface/autoformer-tourism-monthly")

# during training, one provides both past and future values
# as well as possible additional features
outputs = model(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=batch["static_categorical_features"],
    static_real_features=batch["static_real_features"],
    future_values=batch["future_values"],
    future_time_features=batch["future_time_features"],
)

loss = outputs.loss
loss.backward()

# during inference, one only provides past values
# as well as possible additional features
# the model autoregressively generates future values
outputs = model.generate(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=batch["static_categorical_features"],
    static_real_features=batch["static_real_features"],
    future_time_features=batch["future_time_features"],
)

mean_prediction = outputs.sequences.mean(dim=1)

The outputs = model(...) call fails with a matrix shape mismatch:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1536x23 and 22x64)
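One way to read this message: when nn.Linear receives a 3-D input, PyTorch folds the leading dimensions, so a (64, 24, 23) tensor is reported as 1536x23, and mat2 is the transposed weight of a layer built for 22 input features and 64 outputs. The sketch below reproduces the same kind of error with a stand-alone nn.Linear whose sizes are taken from the message; it is not the actual Autoformer layer, and it only shows that the tensor reaching that layer carries one more feature per time step (23) than the layer expects (22), not which input contributes the extra feature.

```python
import torch
import torch.nn as nn


def reproduce_shape_mismatch():
    # Sizes read off the error message; these are illustrative stand-ins,
    # not values extracted from the Autoformer model itself.
    batch_size, prediction_length = 64, 24
    received_features, expected_features, out_features = 23, 22, 64

    # A (batch, time, features) input with 23 features per time step.
    decoder_input = torch.randn(batch_size, prediction_length, received_features)
    # A projection built for 22 input features: weight shape is (64, 22),
    # so the transposed weight used in the matmul is 22x64.
    projection = nn.Linear(expected_features, out_features)

    try:
        projection(decoder_input)
    except RuntimeError as err:
        # The leading dims are folded: 64 * 24 = 1536 rows of 23 features.
        return str(err)
    return None


message = reproduce_shape_mismatch()
print(message)
```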

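In the corresponding dataset, the batch size is 64, the input length is 61, the prediction length is 24, and there are two time features. With my limited knowledge I can only tell that 1536 = 64*24; I can't find any pattern behind the other dimensions. The earlier AutoformerModel demo is similar, yet its outputs = model(...) step runs without error. How should I fix this? Many thanks!!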
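A rough way to narrow this down is to compare the batch's tensor shapes against the sizes the checkpoint's config declares. The sketch below uses a hypothetical helper, find_shape_mismatches, which reads only a few AutoformerConfig attributes (num_time_features, prediction_length, num_static_real_features, num_static_categorical_features, context_length, lags_sequence) and assumes, per the Hugging Face time-series docs, that past_values spans context_length + max(lags_sequence) steps. The fake config values in the demo are made up to match the sizes described above; with the real objects one would call find_shape_mismatches(model.config, batch). An empty result does not rule out a bug inside the model or a library-version mismatch.

```python
from types import SimpleNamespace


def find_shape_mismatches(config, batch):
    """Compare batch tensor shapes with the sizes a config expects.

    Hypothetical helper: `config` only needs the attributes read below, and
    `batch` maps names to objects exposing a `.shape` (e.g. torch tensors).
    """
    problems = []

    def check(name, actual, expected):
        if actual != expected:
            problems.append(f"{name}: got {actual}, config expects {expected}")

    # Time features: last dim should equal num_time_features.
    check("past_time_features[-1]",
          batch["past_time_features"].shape[-1], config.num_time_features)
    check("future_time_features[-1]",
          batch["future_time_features"].shape[-1], config.num_time_features)
    # Future window length should equal prediction_length.
    check("future_time_features[1]",
          batch["future_time_features"].shape[1], config.prediction_length)
    # Static features: last dim should match the configured counts.
    check("static_real_features[-1]",
          batch["static_real_features"].shape[-1], config.num_static_real_features)
    check("static_categorical_features[-1]",
          batch["static_categorical_features"].shape[-1],
          config.num_static_categorical_features)
    # Assumption: the past window covers context_length plus the largest lag.
    check("past_values[1]",
          batch["past_values"].shape[1],
          config.context_length + max(config.lags_sequence))
    return problems


def fake(*shape):
    # Stand-in for a tensor: only the .shape attribute is used.
    return SimpleNamespace(shape=shape)


# Made-up config values, chosen only so that 24 + 37 = 61 matches the input length.
demo_config = SimpleNamespace(
    num_time_features=2, prediction_length=24, num_static_real_features=1,
    num_static_categorical_features=1, context_length=24, lags_sequence=[1, 2, 37],
)
demo_batch = {
    "past_values": fake(64, 61),
    "past_time_features": fake(64, 61, 2),
    "future_time_features": fake(64, 24, 2),
    "static_real_features": fake(64, 1),
    "static_categorical_features": fake(64, 1),
}
print(find_shape_mismatches(demo_config, demo_batch))
```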
