dropout():参数“输入”(位置1)必须是张量,而不是在使用Bert和Huggingface时的str
我的代码运行良好,当我今天尝试在不更改任何内容的情况下运行它时,出现以下错误:
dropout(): 参数“输入”(位置 1)必须是张量,而不是 str
如果可以提供帮助,将不胜感激。可能是数据加载器的问题?
回答
如果您使用 HuggingFace,此信息可能会很有用。我有同样的错误,并在dropout之前在模型类中添加参数return_dict=False修复它:输出 = 模型(**输入,return_dict = False)
回答
我也在研究同一个 repo。有一个可能名为 Bert_Arch 的类继承了 nn.Module,这个类有一个名为 forward 的重写方法。在 forward 方法中,只需将参数 'return_dict=False' 添加到 self.bert() 方法调用中。代替
_, cls_hs = self.bert(sent_id, attention_mask=mask)
和
_, cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False)
THE END
二维码