forked from shiyuanlsy/A2SL
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_year.py
More file actions
80 lines (61 loc) · 2.67 KB
/
data_year.py
File metadata and controls
80 lines (61 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
from holoviews.operation import threshold
from shapely.lib import centroid
from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
from sklearn.model_selection import train_test_split
import random
from scipy.spatial.distance import euclidean
from sklearn.cluster import KMeans
import pickle
import time
from multiprocessing import Pool, Lock
from torch.utils.data import TensorDataset, DataLoader
class LakeYearlyDataset(Dataset):
    """Dataset of per-lake, per-year CSV samples preloaded into memory.

    Every entry of ``data_dir`` is expected to be a CSV file whose name ends
    in ``..._<lake_id>_<year>.csv``. Entries that are not regular files, whose
    name cannot be parsed that way, or whose table does not have exactly
    ``expected_rows`` rows are skipped with a printed message.

    Column layout, as described by the original inline comments (the CSV
    schema itself is not visible here -- confirm against the data files):
    the first column and the last four columns of each CSV are discarded
    (year/month info); of what remains, the last four columns are obs_epi,
    obs_hyp, filled_epi and filled_hyp, and everything before them is the
    model input (including sim_epi and sim_hyp).

    Attributes:
        data_dir: Directory the samples were read from.
        sample_files: Sorted paths of every entry found in ``data_dir``.
        metadata: Maps file path -> ``{'lake_id': str, 'year': int}`` for
            every file whose name could be parsed.
        data: Maps file path -> ``{'x', 'y_epi', 'y_hyp'}`` DataFrames for
            every sample that was kept.
        valid_samples / samples: Paths of the kept samples, in index order.
    """

    def __init__(self, data_dir, *, expected_rows=360, fill_value=-11):
        """Scan ``data_dir`` and preload every usable CSV.

        Args:
            data_dir: Directory containing the per-lake, per-year CSV files.
            expected_rows: Row count a file must have to be kept.
            fill_value: Value substituted for missing (NaN) observations in
                the obs_epi / obs_hyp targets.
        """
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.sample_files = sorted(
            os.path.join(data_dir, entry) for entry in os.listdir(data_dir)
        )
        self.metadata = {}
        self.data = {}
        self.valid_samples = []

        for file in self.sample_files:
            if not os.path.isfile(file):
                print(f"Skipping entry: {file} (not a regular file)")
                continue

            info = self._parse_file_name(file)
            if info is None:
                print(f"Skipping file: {file} (unrecognized file name)")
                continue
            self.metadata[file] = info

            # Drop the first column and the last four (year and month info).
            data = pd.read_csv(file).iloc[:, 1:-4]
            if data.shape[0] != expected_rows:
                print(f"Skipping file: {file} (Row count: {data.shape[0]})")
                continue

            # fillna returns a new frame, so nothing is written into a slice
            # of ``data`` (the old mask-assignment was a chained assignment).
            y_data_epi = data.iloc[:, [-4]].fillna(fill_value)  # obs_epi
            y_data_hyp = data.iloc[:, [-3]].fillna(fill_value)  # obs_hyp
            # Inputs: everything except obs_epi, obs_hyp, filled_epi, filled_hyp.
            x_data = data.iloc[:, :-4]

            self.data[file] = {
                'x': x_data,
                'y_epi': y_data_epi,
                'y_hyp': y_data_hyp,
            }
            self.valid_samples.append(file)

        self.samples = self.valid_samples

    @staticmethod
    def _parse_file_name(file):
        """Return ``{'lake_id', 'year'}`` parsed from the file name, or None."""
        parts = os.path.basename(file).split('_')
        try:
            lake_id = parts[-2].replace('.csv', '')
            year = int(parts[-1].replace('.csv', ''))
        except (IndexError, ValueError):
            # Fewer than two '_'-separated parts, or a non-integer year.
            return None
        return {'lake_id': lake_id, 'year': year}

    @staticmethod
    def _to_tensor(frame):
        """Convert a DataFrame to a new float32 tensor (data is copied)."""
        return torch.tensor(frame.values.astype(np.float32), dtype=torch.float32)

    def __len__(self):
        """Return the number of samples that were kept."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            A tuple ``(x, obs_epi, obs_hyp, lake_id, year)`` where ``x`` is a
            float32 tensor of shape ``(rows, n_features)``, ``obs_epi`` and
            ``obs_hyp`` are float32 tensors of shape ``(rows, 1)``,
            ``lake_id`` is a string and ``year`` is a float32 tensor of
            shape ``(1,)``.
        """
        sample_file = self.samples[idx]
        sample = self.data[sample_file]
        info = self.metadata[sample_file]

        x = self._to_tensor(sample['x'])
        obs_epi = self._to_tensor(sample['y_epi'])
        obs_hyp = self._to_tensor(sample['y_hyp'])
        year = torch.tensor([info['year']], dtype=torch.float32)
        return x, obs_epi, obs_hyp, info['lake_id'], year