File tree 2 files changed +85
-0
lines changed
2 files changed +85
-0
lines changed Original file line number Diff line number Diff line change
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass (frozen = True )
5
+ class BaseConfig :
6
+ version : str
7
+ model_path : str
8
+ model_s3_bucket : str
9
+ model_file_name : str
10
+
11
+ def generate_model_path (self ) -> str :
12
+ return "{}/{}/{}" .format (
13
+ self .model_path ,
14
+ self .version ,
15
+ self .model_file_name
16
+ )
17
+
18
+ def generate_model_s3_path (self ) -> str :
19
+ return "s3://{}/{}/{}" .format (
20
+ self .model_s3_bucket ,
21
+ self .version ,
22
+ self .model_file_name
23
+ )
24
+
25
+
26
+ @dataclass (frozen = True )
27
+ class Train :
28
+ batch_size : int = 16
29
+ epoch : int = 10
30
+
31
+
32
+ @dataclass (frozen = True )
33
+ class Test :
34
+ batch_size : int = 16
35
+
36
+
37
+ @dataclass (frozen = True )
38
+ class Config (BaseConfig ):
39
+ version : str = "v1.0.0"
40
+ model_path : str = "/ops/models"
41
+ model_s3_bucket : str = "models"
42
+ model_file_name : str = "model.pth"
43
+ train : Train = Train ()
44
+ test : Test = Test ()
Original file line number Diff line number Diff line change
1
+ from dataclasses import FrozenInstanceError
2
+ from unittest import TestCase , main
3
+
4
+ from src .config import BaseConfig , Config
5
+
6
+
7
+ class TestConfig (TestCase ):
8
+ def test_generate_path (self ):
9
+ config = BaseConfig (
10
+ version = "v0.0.0" ,
11
+ model_path = "/opt/models" ,
12
+ model_s3_bucket = "models" ,
13
+ model_file_name = "model.pth"
14
+ )
15
+
16
+ expected_model_path = "/opt/models/v0.0.0/model.pth"
17
+ expected_model_s3_path = "s3://models/v0.0.0/model.pth"
18
+
19
+ self .assertEqual (config .generate_model_path (), expected_model_path )
20
+ self .assertEqual (config .generate_model_s3_path (), expected_model_s3_path )
21
+
22
+ def test_config_can_call_method (self ):
23
+ config = Config ()
24
+ config .generate_model_path ()
25
+ config .generate_model_s3_path ()
26
+
27
+ def test_config_is_immutable (self ):
28
+ config = Config ()
29
+
30
+ with self .assertRaises (FrozenInstanceError ):
31
+ config .version = "v0.0.0"
32
+
33
+ with self .assertRaises (FrozenInstanceError ):
34
+ config .train .epoch = 10000
35
+
36
+ with self .assertRaises (FrozenInstanceError ):
37
+ config .test .batch_siz = 64
38
+
39
+
40
+ if __name__ == '__main__' :
41
+ main ()
You can’t perform that action at this time.
0 commit comments