forked from DominicVinxander/DSFNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path local_train.py
More file actions
44 lines (32 loc) · 1.35 KB
/
local_train.py
File metadata and controls
44 lines (32 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# ===========================
# -*- coding: utf-8 -*-
# Author: DSFNet Authors
# Date: 2024-09-27
# Keep Coding, Keep Thinking
# ===========================
"""Small local DSFNet training run.

Overrides a few settings on ``dsfnet_configer`` for local use. Then it reads
a capped number of rows from the train and test tables, runs the graph's
transform / unpack / train ops, saves the graph, and prints the checkpoint
value returned by ``train_op``.
"""
import numpy as np
# Never summarize arrays with "..." when printing; show them in full.
# NOTE: this (and the pandas option below) deliberately runs before the
# project modules are imported, exactly as before, to keep import-time order.
np.set_printoptions(threshold=np.inf)
import pandas as pd
# Show every DataFrame column when printing (None = no limit).
pd.set_option('display.max_columns', None)
from corona_train.core.context import GraphContext
from ucore.config import dsfnet_configer
from ucore.config.column_config import train_columns

# Source tables for this run; both are read with the same column list.
TRAIN_TABLE = "dsfnet_paper_train_data_www_2025"
VALID_TABLE = "dsfnet_paper_test_data_www_2025"
# Passed as ``limit_num`` to both reads. Presumably caps the rows fetched so
# a local run stays small -- TODO confirm against GraphContext.source_op.
LOCAL_ROW_LIMIT = 500

# Local-run overrides on the shared config. They are assigned before
# GraphContext() is constructed, presumably so it picks them up -- confirm.
dsfnet_configer.read_epochs = 1
dsfnet_configer.batch_size = 52
# Presumably step intervals for heartbeat / eval logging -- confirm.
dsfnet_configer.STEP_PRINT_HEART = 50
dsfnet_configer.STEP_PRINT_EVAL = 50
# NOTE(review): looks like this switches data loading away from OSS and into
# local mode -- confirm in ucore.config.
dsfnet_configer.DATA_SET_OSS = False
dsfnet_configer.LOCAL_TRAIN = True

graph = GraphContext()

print("\n--------------- read data ----------------")
train_set = graph.source_op(table_name=TRAIN_TABLE, columns=train_columns,
                            limit_num=LOCAL_ROW_LIMIT)
valid_set = graph.source_op(table_name=VALID_TABLE, columns=train_columns,
                            limit_num=LOCAL_ROW_LIMIT)

print("\n--------------- transform ----------------")
train_set_ori, valid_set_ori = graph.table_transform_op(train_set, valid_set)

print("\n--------------- unpack ----------------")
train_set, valid_set = graph.unpack_op(train_set_ori, valid_set_ori)

print("\n--------------- start training ----------------")
# train_op returns a pair; the second element is named ckp_path, presumably
# the checkpoint location -- confirm in GraphContext.train_op.
_, ckp_path = graph.train_op(train_set, valid_set)
graph.save()
# Surface the checkpoint location so a local run reports where it went,
# instead of silently discarding the value.
print("\ncheckpoint path:", ckp_path)