
Loss of accuracy when Longformer for SequenceClassification model is exported to ONNX #776

@SteffenHaeussler

Description

Edit: This is a crosspost of pytorch #94810. I don't know where the issue lies.

System info

  • transformers version: 4.26.1
  • Platform: macOS-10.16-x86_64-i386-64bit
  • Python version: 3.9.12
  • PyTorch version (GPU?): 1.13.0 (False)
  • onnx: 1.13.0
  • onnxruntime: 1.13.1

Who can help?

I think @younesbelkada would be a great help :)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

This model was trained on client data and I'm not allowed to share the data or the weights, which makes reproducing this issue much harder. Please let me know if you need more information.

Here is the code snippet for the ONNX conversion:

I followed this tutorial, and I also tried your tutorial. ONNX export with Optimum is not yet available for Longformer, and I haven't figured out yet how to add it.

Conversion:

import numpy as np
from onnxruntime import InferenceSession
from tqdm.auto import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("deployment/best_model/")
model = AutoModelForSequenceClassification.from_pretrained("deployment/best_model/")

model.to("cpu")
model.eval()

example_input = tokenizer(
    dataset["test"]["text"][0], max_length=512, truncation=True, return_tensors="pt"
)
_ = model(**example_input)

torch.onnx.export(
    model,
    tuple(example_input.values()),  # positional inputs: (input_ids, attention_mask)
    f="deployment/model.onnx",  # same path the accuracy script loads
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence"},
        "attention_mask": {0: "batch_size", 1: "sequence"},
        # logits has shape (batch_size, num_labels), so only the batch axis is dynamic
        "logits": {0: "batch_size"},
    },
    do_constant_folding=True,
    opset_version=16,
)
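One difference worth ruling out: `example_input` is truncated but not padded, whereas the accuracy loop below pads every text to `max_length=512`. `torch.onnx.export` traces the model with that single input, so Python-level branches that depend on the sequence length may be frozen for the traced length. One such branch, assuming the Hugging Face implementation, is the padding Longformer applies so that the sequence length becomes a multiple of `attention_window` (512 in the model config). The sketch below reproduces only that arithmetic; `padding_needed` is a hypothetical helper, not a transformers API.

```python
def padding_needed(seq_len: int, attention_window: int) -> int:
    """Tokens appended so that seq_len becomes a multiple of attention_window.

    Assumption: mirrors the (w - L % w) % w rule used by Longformer's internal padding.
    """
    return (attention_window - seq_len % attention_window) % attention_window


# attention_window from the model config is [512, 512]; the maximum is used for a list
window = max([512, 512])

for seq_len in (37, 200, 512):
    print(seq_len, padding_needed(seq_len, window))
# 37 475
# 200 312
# 512 0
```

If the traced example was, say, 37 tokens long, the model padded it by 475 tokens during tracing, whereas a 512-token input needs no padding at all. Re-exporting with `padding="max_length"` in the tokenizer call for `example_input` would be one way to test whether this matters.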

Calculating the accuracy:

session = InferenceSession("deployment/model.onnx", providers=["CPUExecutionProvider"])

y_hat_torch = []
y_hat_onnx = []

for text in dataset["test"]["text"]:
    # ONNX Runtime: numpy inputs; predicted class = argmax over the logits of the single sample
    tok_text = tokenizer(
        text, padding="max_length", max_length=512, truncation=True, return_tensors="np"
    )
    pred = session.run(None, input_feed=dict(tok_text))
    y_hat_onnx.append(int(np.argmax(pred[0][0])))

    # PyTorch: same tokenization as torch tensors; no gradients needed for inference
    tok_text = tokenizer(
        text, padding="max_length", max_length=512, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        pred = model(**tok_text)
    y_hat_torch.append(int(torch.argmax(pred[0][0])))

labels = dataset["test"]["label"]
print(f"Accuracy onnx: {sum(int(i) == int(j) for i, j in zip(y_hat_onnx, labels)) / len(y_hat_onnx):.2f}")
print(f"Accuracy torch: {sum(int(i) == int(j) for i, j in zip(y_hat_torch, labels)) / len(y_hat_torch):.2f}")
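A slightly more diagnostic comparison, as a minimal sketch assuming the `y_hat_onnx`, `y_hat_torch` and `dataset["test"]["label"]` lists from the loop above: besides the accuracy of each backend, it reports how often the two backends agree with each other. A faithful export should agree with PyTorch on nearly every sample, independent of how accurate the model is. `compare_predictions` is a hypothetical helper.

```python
from typing import Dict, Sequence


def compare_predictions(
    y_onnx: Sequence[int], y_torch: Sequence[int], labels: Sequence[int]
) -> Dict[str, float]:
    """Accuracy of each backend plus the rate at which both backends predict the same class."""
    if not (len(y_onnx) == len(y_torch) == len(labels)):
        raise ValueError("all sequences must have the same length")
    n = len(labels)
    acc_onnx = sum(int(p) == int(t) for p, t in zip(y_onnx, labels)) / n
    acc_torch = sum(int(p) == int(t) for p, t in zip(y_torch, labels)) / n
    agreement = sum(int(a) == int(b) for a, b in zip(y_onnx, y_torch)) / n
    return {"acc_onnx": acc_onnx, "acc_torch": acc_torch, "agreement": agreement}


# Toy example with made-up predictions
print(compare_predictions([0, 1, 1, 2], [0, 1, 2, 2], [0, 1, 2, 2]))
# {'acc_onnx': 0.75, 'acc_torch': 1.0, 'agreement': 0.75}
```

Applied to the real run, this would be `compare_predictions(y_hat_onnx, y_hat_torch, dataset["test"]["label"])`.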

I also looked into the model weights, and some weights of the encoder layers differ between PyTorch and ONNX. Here is an example:

import torch
import onnx
from onnx import numpy_helper

import numpy as np
from numpy.testing import assert_almost_equal

from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("deployment/best_model/")
onnx_model = onnx.load("deployment/model.onnx")

graph = onnx_model.graph

initalizers = dict()
for init in graph.initializer:
    initalizers[init.name] = numpy_helper.to_array(init).astype(np.float16)

model_init = dict()
for name, p in model.named_parameters():
    model_init[name] = p.detach().numpy().astype(np.float16)

assert len(initalizers) == len(model_init.keys())  # 53 parameter tensors

assert_almost_equal(initalizers['longformer.embeddings.word_embeddings.weight'], 
                    model_init['longformer.embeddings.word_embeddings.weight'], decimal=5)

assert_almost_equal(initalizers['classifier.dense.weight'], 
                    model_init['classifier.dense.weight'], decimal=5)

For the weight longformer.encoder.layer.0.output.dense.weight, which matches onnx::MatMul_6692 in shape and position:

assert_almost_equal(initalizers['onnx::MatMul_6692'], 
                    model_init['longformer.encoder.layer.0.output.dense.weight'], decimal=4)

I get

AssertionError: 
Arrays are not almost equal to 4 decimals

Mismatched elements: 2356293 / 2359296 (99.9%)
Max absolute difference: 1.776
Max relative difference: inf
 x: array([[ 0.0106,  0.1076,  0.0801, ...,  0.0425,  0.1548,  0.0123],
       [-0.0399, -0.1415,  0.0916, ...,  0.0181, -0.1277, -0.1335],
       [-0.0961,  0.0013,  0.0558, ..., -0.1354, -0.0965,  0.0447],...
 y: array([[-0.0699,  0.0743,  0.0339, ...,  0.0564, -0.087 ,  0.0649],
       [-0.1315, -0.0967, -0.045 , ..., -0.0492,  0.0775,  0.0284],
       [-0.1094,  0.0364,  0.1263, ..., -0.0308, -0.0118,  0.1523],...
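The mismatch may be a layout issue rather than corrupted weights. PyTorch's `nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ W.T`; my assumption is that, with `do_constant_folding=True`, the exported `MatMul` initializer holds the already-transposed `W.T`. The compared tensors have 2,359,296 elements, i.e. 768 × 3072. Under that assumption, the initializer for `layer.0.output.dense` would be 3072 × 768, while a 768 × 3072 initializer (the same shape as `output.dense.weight` in PyTorch) could instead be the transposed `intermediate.dense.weight`. The hypothetical helper `find_match` below looks up an ONNX array among the PyTorch tensors, both directly and transposed; it is demonstrated on random arrays, not on the real model.

```python
import numpy as np


def find_match(onnx_array, torch_params, atol=1e-3):
    """Return (name, transposed) for the first PyTorch tensor equal to onnx_array, else None."""
    for name, weight in torch_params.items():
        # Direct match: same shape and values
        if weight.shape == onnx_array.shape and np.allclose(weight, onnx_array, atol=atol):
            return name, False
        # Transposed match: how an nn.Linear weight would look as a MatMul initializer
        if weight.ndim == 2 and weight.T.shape == onnx_array.shape and np.allclose(
            weight.T, onnx_array, atol=atol
        ):
            return name, True
    return None


rng = np.random.default_rng(0)
# Shapes as PyTorch stores them: (out_features, in_features)
torch_params = {
    "intermediate.dense.weight": rng.normal(size=(3072, 768)).astype(np.float32),
    "output.dense.weight": rng.normal(size=(768, 3072)).astype(np.float32),
}
# Simulated MatMul initializer for intermediate.dense: the transposed weight, shape (768, 3072)
onnx_initializer = torch_params["intermediate.dense.weight"].T.copy()

print(onnx_initializer.shape == torch_params["output.dense.weight"].shape)  # True
print(find_match(onnx_initializer, torch_params))  # ('intermediate.dense.weight', True)
```

With the script above, `find_match(initalizers['onnx::MatMul_6692'], model_init)` would show whether that initializer corresponds to some PyTorch weight, possibly transposed; the default `atol=1e-3` leaves room for the float16 casts.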

Model config:

{
  "_name_or_path": "/datadrive/model/onnx/",
  "architectures": [
    "LongformerForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "attention_window": [
    512,
    512
  ],
  "bos_token_id": 1,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 514,
  "model_type": "longformer",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "onnx_export": false,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "problem_type": "single_label_classification",
  "sep_token_id": 2,
  "torch_dtype": "float32",
  "transformers_version": "4.26.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 32768
}
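The count of 53 in the assertion above can be derived from this config (`num_hidden_layers: 2`), assuming the parameter layout of the Hugging Face Longformer implementation: five embedding tensors, six projections per self-attention block (including the separate global query/key/value), and no pooler in the sequence-classification model. A minimal sketch; the per-layer names are informal labels, not exact parameter names. That the ONNX model also has 53 initializers suggests no parameter tensor was dropped during export.

```python
# Informal labels for the parameter tensors, assuming the Hugging Face Longformer layout
EMBEDDINGS = [
    "word_embeddings.weight",
    "position_embeddings.weight",
    "token_type_embeddings.weight",
    "LayerNorm.weight",
    "LayerNorm.bias",
]
LINEARS_PER_LAYER = [
    "query", "key", "value",                       # local self-attention projections
    "query_global", "key_global", "value_global",  # global attention projections
    "attention_output_dense", "intermediate_dense", "output_dense",
]
LAYERNORMS_PER_LAYER = 2  # after the attention block and after the feed-forward block
CLASSIFIER_LINEARS = 2    # classifier.dense and classifier.out_proj


def count_parameter_tensors(num_hidden_layers: int) -> int:
    """Weight and bias tensors of a LongformerForSequenceClassification model."""
    per_layer = 2 * len(LINEARS_PER_LAYER) + 2 * LAYERNORMS_PER_LAYER  # weight + bias each
    return len(EMBEDDINGS) + num_hidden_layers * per_layer + 2 * CLASSIFIER_LINEARS


print(count_parameter_tensors(2))  # 53
```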

Expected behavior

I would expect similar accuracy for both models, but on test data with 3,800 samples I get:

Accuracy onnx: 17 %
Accuracy torch: 70 %

I would like to know what went wrong and how I can fix it, or who could help me. I'm clueless at the moment.
Alternatively, I could switch to the BigBird architecture, since Optimum already has some support for it.

I trained a small Longformer language model from scratch and fine-tuned it on custom data with a sequence classification head. I used fp16 for training, and the training ran on a GPU.
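Since fp16 appears both in training and in the weight comparison script (which casts both sides with `.astype(np.float16)`), here is a quick check of whether float16 rounding alone could account for the reported maximum difference of 1.776. A minimal sketch on synthetic values; the value range of ±2 is an assumption based on the printed arrays, which stay well inside it.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic weights in a range covering the printed arrays (assumption: |w| < 2)
weights = rng.uniform(-2.0, 2.0, size=100_000).astype(np.float32)

# Round-trip through float16, as the comparison script does with .astype(np.float16)
roundtrip = weights.astype(np.float16).astype(np.float32)
max_err = float(np.max(np.abs(weights - roundtrip)))

# float16 has a 10-bit mantissa: the spacing in [1, 2) is 2**-10,
# so round-to-nearest is off by at most half of that, 2**-11 (about 0.0005)
print(max_err <= 2**-11)  # True
```

A rounding error below 0.0005 is more than three orders of magnitude smaller than 1.776, so the float16 casts by themselves would not explain the mismatch.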
