Commit 42e1e3c

Update usage doc regarding generate fn (#3504)
1 parent 57b0fab commit 42e1e3c

1 file changed: +4 -8 lines changed

docs/source/usage.rst (+4 -8)
@@ -420,7 +420,7 @@ to generate the tokens following the initial sequence in PyTorch, and creating a
 sequence = f"Hugging Face is based in DUMBO, New York City, and is"

 input = tokenizer.encode(sequence, return_tensors="pt")
-generated = model.generate(input, max_length=50)
+generated = model.generate(input, max_length=50, do_sample=True)

 resulting_string = tokenizer.decode(generated.tolist()[0])
 print(resulting_string)
@@ -432,14 +432,10 @@ to generate the tokens following the initial sequence in PyTorch, and creating a
 model = TFAutoModelWithLMHead.from_pretrained("gpt2")

 sequence = f"Hugging Face is based in DUMBO, New York City, and is"
-generated = tokenizer.encode(sequence)
-
-for i in range(50):
-    predictions = model(tf.constant([generated]))[0]
-    token = tf.argmax(predictions[0], axis=1)[-1].numpy()
-    generated += [token]
+input = tokenizer.encode(sequence, return_tensors="tf")
+generated = model.generate(input, max_length=50, do_sample=True)

-resulting_string = tokenizer.decode(generated)
+resulting_string = tokenizer.decode(generated.tolist()[0])
 print(resulting_string)
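For reference, a minimal sketch of the updated PyTorch snippet as it reads after this change. Only the encode/generate/decode lines come from the hunks above; the imports and the tokenizer/model setup are not part of the diff and are assumed here (AutoTokenizer and AutoModelWithLMHead as used elsewhere in the transformers docs of that era):

    # Sketch of the updated usage after this commit; setup lines are assumptions,
    # only the encode/generate/decode calls are taken from the diff itself.
    from transformers import AutoModelWithLMHead, AutoTokenizer  # assumed imports, not shown in the diff

    tokenizer = AutoTokenizer.from_pretrained("gpt2")    # assumed setup from the surrounding doc section
    model = AutoModelWithLMHead.from_pretrained("gpt2")  # assumed setup from the surrounding doc section

    sequence = f"Hugging Face is based in DUMBO, New York City, and is"

    # Encode the prompt as a PyTorch tensor and let generate() run the decoding loop.
    # do_sample=True samples from the model's distribution at each step instead of
    # taking the argmax, which is what the removed hand-written TensorFlow loop did.
    input = tokenizer.encode(sequence, return_tensors="pt")
    generated = model.generate(input, max_length=50, do_sample=True)

    # generate() returns a batch of token-id sequences; decode the first one.
    resulting_string = tokenizer.decode(generated.tolist()[0])
    print(resulting_string)

The TensorFlow variant in the second hunk is the same apart from return_tensors="tf" and the TFAutoModelWithLMHead class shown in the diff context.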