
[Critical] Very high loss rate at first few tokens (classifier free guidance not working) #80

@MarcusLoppe

Description


@lucidrains
This is an issue I've been having for a while: the cross-attention is very weak at the start of the sequence.
When the transformer starts with no tokens it has to rely on the cross-attention, but unfortunately the cross-attention doesn't work for the first token(s).
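
For reference, classifier-free guidance at sampling time combines a conditioned and an unconditioned pass roughly like the generic sketch below (this is not the actual meshgpt-pytorch internals). If the cross-attention barely influences the first positions, the two passes are presumably almost identical there, so the guidance term has nothing to amplify, which would match the symptom in the title.

```python
import torch

def cfg_logits(cond_logits: torch.Tensor,
               uncond_logits: torch.Tensor,
               guidance_scale: float = 3.0) -> torch.Tensor:
    # Standard classifier-free guidance combination: push the conditioned
    # logits further away from the unconditioned ones by the guidance scale.
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)
```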

Proof

To prove this I trained on a dataset of 500 models with unique text embeddings and no augmentations, and I only took the first 6 tokens of each mesh and trained on those.
After training for 8 hours it's still stuck at a loss of 1.03.
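
A minimal sketch of this overfit test, assuming the meshes have already been tokenized by the autoencoder (the dataset layout and field names here are hypothetical): keep only the first 6 codes per mesh and train the text-conditioned transformer on just that prefix.

```python
import torch
from torch.utils.data import Dataset

PREFIX_LEN = 6  # only the very first codes, where the failure shows up

class PrefixOnlyDataset(Dataset):
    """Wraps a tokenized mesh dataset and truncates every sequence to its first few codes."""

    def __init__(self, samples):
        # samples: list of dicts with 'codes' (LongTensor) and 'text' (str) -- hypothetical format
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        return {
            'codes': item['codes'][:PREFIX_LEN],
            'text': item['text'],
        }
```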

Without fixing this issue, autoregressive generation without a token prompt will never work.

This problem has been ongoing for a while, but I thought it was a training issue and that using a model trained on the first few tokens would resolve it. However, that isn't the case.
Real-life example
To highlight the issue, I trained a model on the 13k dataset, then removed all the augmentation copies and removed models with duplicate labels.
If I provide it with the first 2 tokens as a prompt it will autocomplete with no problem and no visual issues; however, if I provide it with 1 or 0 tokens it fails completely.
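
For context, this is roughly how the two cases are invoked. The `prompt` keyword (and the return values) are assumptions about the meshgpt-pytorch `generate()` signature; `texts` and `cond_scale` follow the README.

```python
import torch

def compare_prompted_vs_unprompted(transformer, known_codes: torch.Tensor, text: str):
    # Autocompleting from the first 2 ground-truth codes works fine...
    prompted = transformer.generate(
        texts = [text],
        prompt = known_codes[:, :2],   # hypothetical keyword for a token-prefix prompt
        cond_scale = 3.0,
    )

    # ...but generating from scratch, relying purely on the text condition, fails.
    unprompted = transformer.generate(
        texts = [text],
        cond_scale = 3.0,
    )
    return prompted, unprompted
```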

Checked the logits

I investigated this further and checked the logits when it generated the first token: the correct token was only the 9th most probable.
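
A small sketch of how that rank can be checked, assuming access to the logits of the first generated position and the ground-truth first code (both names here are illustrative):

```python
import torch

def rank_of_correct_token(first_step_logits: torch.Tensor, correct_token: int) -> int:
    """Return the 1-based rank of the ground-truth token in the model's first-step distribution."""
    # first_step_logits: (vocab_size,) logits for the very first generated position
    order = first_step_logits.argsort(descending = True)
    return (order == correct_token).nonzero(as_tuple = True)[0].item() + 1

# e.g. rank_of_correct_token(logits[0, 0], gt_codes[0, 0].item()) -> 9 in the case described above
```
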
I tried to implement a beam search with a beam width of 5, but since the correct first token has such a low probability, it would require a lot of beams. That would probably work, but it's a brute-force solution and not a very good one.
It may work to do a beam search of width 20 and then kill off the candidates that seem to have a low probability/entropy, but this feels like a band-aid solution that might not hold up when scaling up meshgpt.
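
A rough sketch of that wide-beam-then-prune idea, written against a generic next-token log-probability function; `log_probs_fn` and everything else here is hypothetical, not the meshgpt-pytorch API.

```python
import torch

def beam_search_with_pruning(log_probs_fn, max_len, start_beams = 20, keep_beams = 5, prune_after = 4):
    """Start wide to survive the weak first-token distribution, then prune to a normal beam width.

    log_probs_fn(tokens) -> (num_beams, vocab) next-token log-probabilities,
    given `tokens` of shape (num_beams, cur_len). Purely illustrative.
    """
    # seed the beams from the first-step distribution
    first = log_probs_fn(torch.empty(1, 0, dtype = torch.long))   # (1, vocab)
    scores, tokens = first[0].topk(start_beams)                   # (start_beams,)
    beams = tokens.unsqueeze(-1)                                  # (start_beams, 1)

    for step in range(1, max_len):
        next_lp = log_probs_fn(beams)                             # (num_beams, vocab)
        total = scores.unsqueeze(-1) + next_lp                    # accumulated log-probs
        width = keep_beams if step >= prune_after else beams.shape[0]
        flat_scores, flat_idx = total.view(-1).topk(width)
        beam_idx = torch.div(flat_idx, next_lp.shape[-1], rounding_mode = 'floor')
        tok_idx = flat_idx % next_lp.shape[-1]
        beams = torch.cat([beams[beam_idx], tok_idx.unsqueeze(-1)], dim = -1)
        scores = flat_scores

    return beams[scores.argmax()]   # best-scoring sequence
```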

Why is this a problem?

The first tokens are very important for the generation since it's a domino effect: if the model picks an incorrect token at the start, the generation will fail, because it relies too much on the existing sequence and cannot auto-correct.
It's as if the target sentence is "Dog can be a happy animal" and the model predicts "Human" as the first token: it won't be able to auto-correct, since the sentence is already off track, and the chance it recovers with something like "Human got a dog which can be a happy animal" is extremely slim.

Possible solution

Since the cross-attention is only used in the "big" decoder, could it also be implemented for the fine decoder?
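
One way this could look, sketched with x-transformers (which meshgpt-pytorch builds on). Whether the fine decoder's inputs and shapes actually allow this is an assumption on my part; the sketch only shows the mechanism.

```python
import torch
from x_transformers import Decoder

# Hypothetical fine decoder with cross-attention enabled, so the per-face
# refinement stage can also attend to the text embeddings.
fine_decoder = Decoder(
    dim = 512,
    depth = 2,
    heads = 8,
    cross_attend = True,   # adds cross-attention blocks after each self-attention block
)

fine_tokens = torch.randn(1, 9, 512)    # e.g. the tokens of one face being refined
text_embeds = torch.randn(1, 77, 512)   # text condition sequence (illustrative shape)

out = fine_decoder(fine_tokens, context = text_embeds)
```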

Attempts to fix:

  • I've tried removing the fine decoder and the fine gateloop.
  • I also tried increasing cross_attn_num_mem_kv but found no significant changes.
  • I replaced the TextEmbeddingReturner with AttentionTextConditioner, but still no changes.
  • I tried using different text encoders such as BGE and CLIP (see the constructor sketch after this list).
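
For reference, these knobs are all passed when constructing the transformer. The sketch below shows roughly where they live; keyword names beyond the ones already mentioned in this issue are assumptions and may not match the current meshgpt-pytorch constructor exactly.

```python
from meshgpt_pytorch import MeshAutoencoder, MeshTransformer

autoencoder = MeshAutoencoder(num_discrete_coors = 128)

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768,
    condition_on_text = True,
    text_condition_model_types = ('t5',),   # swap in e.g. a CLIP or BGE encoder here
    cross_attn_num_mem_kv = 4,              # the memory key/values mentioned above
)
```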

This has been a problem for a long time and I've only mentioned it as a note in other issue threads, so I'm creating an issue for it since it really prevents me from releasing fine-tuned models.

I have a model ready to go that can predict 13k models, but since the first tokens make autoregressive generation impossible, I've not released it yet.

Here are some images of the loss:
[image: loss curves]
