-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample_usage.py
More file actions
72 lines (55 loc) · 2.47 KB
/
example_usage.py
File metadata and controls
72 lines (55 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python3
"""
Example script demonstrating how to use the TinyLLM inference system.
This script shows different ways to load and use the trained model.
"""
from tools.inference import ModelInference, load_model
def main():
    """Demonstrate the TinyLLM inference API.

    Loads a trained checkpoint, runs a batch of example prompts through
    several decoding strategies (sampling, greedy, top-k), inspects
    next-token probabilities, then drops into an interactive prompt loop.

    Side effects: reads ``checkpoints/last_model.pt`` from disk, prints to
    stdout, and reads from stdin in the interactive section.
    """
    print("=== TinyLLM Inference Example ===\n")

    # Load the model directly via the class (the load_model() convenience
    # wrapper is an alternative). device="auto" lets the backend pick
    # GPU/CPU — presumably resolved inside ModelInference; confirm there.
    print("1. Loading model manually...")
    model = ModelInference(checkpoint_path="checkpoints/last_model.pt", device="auto")

    # Test prompts covering open-ended and factual completions.
    prompts = [
        "The future of artificial intelligence is",
        "Once upon a time in a distant galaxy",
        "In the field of machine learning,",
        "The capital of France is"
    ]

    print("\n=== Generation Examples ===")
    for i, prompt in enumerate(prompts, 1):
        print(f"\n--- Prompt {i}: {prompt} ---")

        # Same prompt through three decoding strategies for comparison.
        print("\nSampling generation (temperature=0.8):")
        generated = model.generate(prompt, max_length=60, temperature=0.8)
        print(f"Result: {generated}")

        print("\nGreedy generation (temperature=0.0):")
        generated_greedy = model.generate(prompt, max_length=60, temperature=0.0, do_sample=False)
        print(f"Result: {generated_greedy}")

        print("\nTop-k sampling (k=10):")
        generated_topk = model.generate(prompt, max_length=30, temperature=0.8, top_k=10, do_sample=True)
        print(f"Result: {generated_topk}")

    print("\n=== Next Token Probabilities ===")
    test_prompt = "The weather today is"
    print(f"\nAnalyzing next token probabilities for: '{test_prompt}'")
    # Expected to yield (token, probability) pairs for the top-5 candidates.
    probs = model.get_next_token_probabilities(test_prompt, top_k=5)
    for token, prob in probs:
        print(f"  Token: '{token}' | Probability: {prob:.4f}")

    print("\n=== Interactive Mode ===")
    print("Enter prompts to generate text (type 'quit' to exit):")
    while True:
        # FIX: input() raises EOFError on closed stdin (e.g. piped input)
        # and KeyboardInterrupt on Ctrl-C; exit cleanly instead of
        # dumping a traceback.
        try:
            user_input = input("\nPrompt: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break
        if user_input.lower() == 'quit':
            print("Goodbye!")
            break
        if user_input:
            try:
                generated = model.generate(user_input, max_length=80, temperature=0.8)
                print(f"Generated: {generated}")
            except Exception as e:
                # Best-effort: report generation failures but keep the
                # interactive session alive.
                print(f"Error during generation: {e}")
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()