@lucidrains
This is an issue I've been having for a while: the cross-attention is very weak at the start of the sequence.
When the transformer starts with no tokens it has to rely on the cross-attention, but unfortunately the cross-attention doesn't work for the first token(s).
Proof
To prove this I trained on a dataset of 500 models with unique text embeddings and no augmentations, keeping only the first 6 tokens of each mesh and training on that.
After training for 8 hours, the loss is still stuck at 1.03.
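For reference, here is a minimal sketch of that overfit check, assuming the mesh token sequences have already been produced by the autoencoder; the dataset and names below are placeholders, not the actual training code:

```python
import torch

# Hypothetical stand-in for the 500-model dataset: each entry is the mesh
# token sequence from the autoencoder, paired with one unique text label.
token_seqs = [torch.randint(0, 16384, (800,)) for _ in range(500)]

def truncate_to_prefix(seqs, prefix_len=6):
    # Keep only the first `prefix_len` tokens so the model has almost no
    # sequence context and must rely on the text cross-attention.
    return [s[:prefix_len] for s in seqs]

prefixes = truncate_to_prefix(token_seqs)
assert all(len(p) == 6 for p in prefixes)
```

If the cross-attention carried enough signal, training on these 6-token prefixes should reach near-zero loss quickly; instead it plateaus.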
Without fixing this issue, autoregressive generation without a token prompt will never work.
This problem has been ongoing for a while, but I thought it was a training issue and that using a model trained on the first few tokens would resolve it. However, that isn't the case.
Real-life example
To highlight the issue, I trained a model on the 13k dataset, then removed all the augmentation copies and models with duplicate labels.
If I provide it with the first 2 tokens as a prompt it will autocomplete with no problem and no visual issues; however, if I provide it with 1 or 0 tokens it fails completely.
Checked the logits
I investigated this further and checked the logits when generating the first token: the correct token was only the 9th most probable.
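For anyone who wants to reproduce the check, this is roughly how I ranked the ground-truth token in the first-step logits (the logits and vocab size here are dummy values; in the real check they come from the transformer's first decoding step and the ground-truth mesh):

```python
import torch
import torch.nn.functional as F

def rank_of_target(logits: torch.Tensor, target_id: int) -> int:
    # logits: (vocab_size,) raw scores for the first generated position.
    # Returns the 1-based rank of the ground-truth token, i.e. how many
    # tokens the model prefers over the correct one, plus one.
    probs = F.softmax(logits, dim=-1)
    order = torch.argsort(probs, descending=True)
    return (order == target_id).nonzero(as_tuple=True)[0].item() + 1

# Toy example with random logits over a made-up vocab.
logits = torch.randn(16384)
print(rank_of_target(logits, target_id=42))
```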
I tried to implement a beam search with a beam width of 5, but since the first token has such a low probability it would require a lot of beams. That would probably work, but it's a brute-force solution and not a good one.
It may work to do a beam search of width 20 and then kill off the candidates that have a low probability/entropy, but this seems like a band-aid that might not hold up when scaling up MeshGPT.
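A rough sketch of what I mean by beam search plus pruning, with a toy step function standing in for a forward pass of the mesh transformer (the beam width, pruning threshold and step function are all illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def beam_search(step_fn, max_len, beam_width=20, min_avg_logprob=-8.0):
    # step_fn(tokens) -> logits over the next token. Here it is a toy stand-in;
    # in the real setup it would be one forward pass of the mesh transformer.
    beams = [(0.0, [])]  # (cumulative log-prob, token list)
    for _ in range(max_len):
        candidates = []
        for score, toks in beams:
            logp = F.log_softmax(step_fn(toks), dim=-1)
            top_lp, top_ix = logp.topk(beam_width)
            for lp, ix in zip(top_lp.tolist(), top_ix.tolist()):
                candidates.append((score + lp, toks + [ix]))
        candidates.sort(key=lambda c: c[0], reverse=True)
        # Keep the best beams, then kill off ones whose average log-prob is low.
        kept = [c for c in candidates[:beam_width]
                if c[0] / len(c[1]) > min_avg_logprob]
        beams = kept or candidates[:1]
    return beams

# Toy step function (random logits over a 128-token vocab) so the sketch runs.
toy_step = lambda toks: torch.randn(128)
print(beam_search(toy_step, max_len=4, beam_width=5)[0])
```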
Why is this a problem?
The first tokens are very important for the generation since it's a domino effect: if it gets an incorrect token at the start, the generation will fail because it relies too much on the sequence to auto-correct.
It's like if the sentence is "Dog can be a happy animal" and it predicts "Human" as the first token: it won't be able to auto-correct since the sentence is already off track, and the chance it recovers to something like "Human got a dog which can be a happy animal" is extremely low.
Possible solution
Since the cross-attention is used only on the "big" (coarse) decoder, can it also be implemented for the fine decoder?
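Something along these lines is what I had in mind, as a concept only and not meshgpt-pytorch's actual code: a fine-decoder block that also cross-attends to the text embeddings, the way the coarse decoder already does. The dimensions are made up and the causal mask is omitted for brevity.

```python
import torch
from torch import nn

class FineBlockWithTextCrossAttn(nn.Module):
    # Conceptual sketch: self-attention over the fine token sequence,
    # followed by cross-attention into the text embeddings.
    def __init__(self, dim=512, heads=8, text_dim=512):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, heads, kdim=text_dim,
                                                vdim=text_dim, batch_first=True)
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x, text_embeds):
        # x: (batch, fine_seq_len, dim); text_embeds: (batch, text_len, text_dim)
        h, _ = self.self_attn(self.norm1(x), self.norm1(x), self.norm1(x))
        x = x + h
        h, _ = self.cross_attn(self.norm2(x), text_embeds, text_embeds)
        return x + h

block = FineBlockWithTextCrossAttn()
out = block(torch.randn(2, 9, 512), torch.randn(2, 16, 512))
print(out.shape)  # torch.Size([2, 9, 512])
```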
Attempts to fix:
- I've tried removing the fine decoder and fine gateloop
- I also tried increasing cross_attn_num_mem_kv but found no significant changes.
- I replaced the TextEmbeddingReturner with AttentionTextConditioner but still no changes.
- Tried using different text encoders such as BGE and CLIP.
This has been a problem for a long time and I've mentioned it as a note in other issue threads, so I'm creating a dedicated issue since it really prevents me from releasing fine-tuned models.
I have a model ready to go that can predict 13k models, but since the first-token problem makes autoregressive generation from scratch unreliable, I've not released it yet.