-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsubmitfile.py
More file actions
108 lines (87 loc) · 4.46 KB
/
submitfile.py
File metadata and controls
108 lines (87 loc) · 4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import pandas as pd
from transformers import AutoTokenizer,AutoModelForCausalLM
import os
def unlearning(model_path, output_path, forget_path, retain_path):
    """Gradient-based unlearning via a mixed-teacher KL objective.

    Loads a causal LM twice: one copy is trained (``unlearn_model``) and one
    frozen copy serves as the "good" teacher.  For retain samples the student
    is pulled toward the good teacher's output distribution; for forget
    samples it is pulled toward a random "bad teacher" distribution, erasing
    the memorised answer.  The updated model and its tokenizer are saved to
    ``output_path``.

    Args:
        model_path: HF checkpoint directory for the base model/tokenizer.
        output_path: directory where the unlearned model is saved.
        forget_path: parquet file with the forget split (assumes columns
            ``input``, ``output``, ``split`` — TODO confirm schema).
        retain_path: parquet file with the retain split (same schema).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    class UnlearningDataset(torch.utils.data.Dataset):
        """Wraps a DataFrame of (input, output, split) rows for training."""

        def __init__(self, tokenizer, data):
            self.tokenizer = tokenizer
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            row = self.data.iloc[index]
            # The prompt alone feeds the models; labels cover prompt + answer.
            prompt = self.tokenizer(row["input"], padding="max_length",
                                    truncation=True, max_length=512,
                                    return_tensors=None)
            labels = self.tokenizer(f"{row['input']} {row['output']}",
                                    padding="max_length", truncation=True,
                                    max_length=512, return_tensors=None)
            # Unpadded prompt length marks where the answer tokens start.
            start_locs = self.tokenizer(row["input"])
            return {
                "input_ids": torch.tensor(prompt["input_ids"]),
                "attention_mask": torch.tensor(prompt["attention_mask"]),
                "start_locs": len(start_locs["input_ids"]) - 1,
                "labels": torch.tensor(labels["input_ids"]),
                # 1 -> forget sample (bad teacher), 0 -> retain sample.
                "split": 1 if row["split"] == "forget" else 0,
            }

    # Preparing data: concatenate retain and forget splits into one loader.
    retain_df = pd.read_parquet(retain_path, engine='pyarrow')
    forget_df = pd.read_parquet(forget_path, engine='pyarrow')
    train_data = pd.concat([retain_df, forget_df], ignore_index=True)
    dataset = UnlearningDataset(tokenizer, train_data)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

    unlearn_model = AutoModelForCausalLM.from_pretrained(model_path)
    good_teacher = AutoModelForCausalLM.from_pretrained(model_path)
    optimizer = torch.optim.SGD(unlearn_model.parameters(), lr=0.001)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def kl_divergence(current_model, good_teacher, batch, device):
        """Per-batch KL(teacher_mix || student); teacher picked per row."""
        normal_outputs = current_model(
            batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            labels=batch["labels"].to(device),
        )
        with torch.no_grad():
            good_teacher_outputs = good_teacher(
                batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device),
            )
        # l broadcasts the per-sample forget flag over (seq, vocab) dims.
        l = torch.unsqueeze(batch["split"], -1)
        l = torch.unsqueeze(l, -1)
        # BUGFIX: the random "bad teacher" logits were built with .cuda(),
        # which crashed on CPU-only machines even though `device` is passed
        # in; use .to(device) so CPU and GPU both work.
        bad_teacher = torch.normal(
            mean=0, std=1, size=good_teacher_outputs.logits.shape
        ).to(device) + torch.ones(good_teacher_outputs.logits.shape[-1]).to(device)
        # P: teacher distributions; Q: current (student) model.
        prob_p = torch.nn.functional.softmax(bad_teacher, -1)
        prob_f = torch.nn.functional.softmax(good_teacher_outputs.logits, -1)
        prob_q = torch.nn.functional.softmax(normal_outputs.logits, -1)
        # Row-wise teacher mix: forget rows (l=1) get the bad teacher.
        out_teacher = (1 - l.to(device)) * prob_f + l.to(device) * prob_p
        # KL(out_teacher || prob_q) with epsilon-stabilised logs.
        loss = (out_teacher * (torch.log(out_teacher + 1e-12)
                               - torch.log(prob_q + 1e-12))).sum(-1).mean()
        return loss

    unlearn_model.to(device)
    good_teacher.to(device)
    unlearn_model.train()
    good_teacher.eval()
    for forget_epoch in range(2):
        for batch in dataloader:
            optimizer.zero_grad()
            loss = kl_divergence(unlearn_model, good_teacher, batch, device)
            print(f"Batch Loss:{loss.item()}")
            loss.backward()
            optimizer.step()
        # BUGFIX: original printed "First Epoch is finsihed" (typo) after
        # EVERY epoch; report the actual epoch number instead.
        print(f"Epoch {forget_epoch + 1} finished")
    unlearn_model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
def main():
    """Entry point: resolve hard-coded dataset/model paths and run unlearning.

    NOTE(review): paths are machine-specific; parameterise via argparse or
    environment variables before sharing this script.
    """
    path = "/data1/malto/unlearning_llm/"
    dataset_path = path + 'datasets/semeval25-unlearning-data/'
    retain_path = dataset_path + 'data/retain_train-00000-of-00001.parquet'
    forget_path = dataset_path + 'data/forget_train-00000-of-00001.parquet'
    # BUGFIX: the original assigned model_path twice; the first value
    # (path + 'models/semeval25-unlearning-model') was dead code, immediately
    # overwritten by the local checkpoint below. The dead assignment is removed.
    model_path = "/home/amunis/Unlearning-sensitive-content-from-LLMs/preunleran_1b"
    output_path = '/home/amunis/Unlearning-sensitive-content-from-LLMs/submit_test_10e3'
    unlearning(model_path, output_path, forget_path, retain_path)


if __name__ == "__main__":
    main()