8000 Add config and test_config · suaaa7/samplecode-for-qiita@2e30521 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2e30521

Browse files
committed
Add config and test_config
1 parent 13c876c commit 2e30521

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

python_ci/src/config.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass(frozen=True)
5+
class BaseConfig:
6+
version: str
7+
model_path: str
8+
model_s3_bucket: str
9+
model_file_name: str
10+
11+
def generate_model_path(self) -> str:
12+
return "{}/{}/{}".format(
13+
self.model_path,
14+
self.version,
15+
self.model_file_name
16+
)
17+
18+
def generate_model_s3_path(self) -> str:
19+
return "s3://{}/{}/{}".format(
20+
self.model_s3_bucket,
21+
self.version,
22+
self.model_file_name
23+
)
24+
25+
26+
@dataclass(frozen=True)
27+
class Train:
28+
batch_size: int = 16
29+
epoch: int = 10
30+
31+
32+
@dataclass(frozen=True)
33+
class Test:
34+
batch_size: int = 16
35+
36+
37+
@dataclass(frozen=True)
38+
class Config(BaseConfig):
39+
version: str = "v1.0.0"
40+
model_path: str = "/ops/models"
41+
model_s3_bucket: str = "models"
42+
model_file_name: str = "model.pth"
43+
train: Train = Train()
44+
test: Test = Test()

python_ci/src/tests/test_config.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from dataclasses import FrozenInstanceError
2+
from unittest import TestCase, main
3+
4+
from src.config import BaseConfig, Config
5+
6+
7+
class TestConfig(TestCase):
8+
def test_generate_path(self):
9+
config = BaseConfig(
10+
version="v0.0.0",
11+
model_path="/opt/models",
12+
model_s3_bucket="models",
13+
model_file_name="model.pth"
14+
)
15+
16+
expected_model_path = "/opt/models/v0.0.0/model.pth"
17+
expected_model_s3_path = "s3://models/v0.0.0/model.pth"
18+
19+
self.assertEqual(config.generate_model_path(), expected_model_path)
20+
self.assertEqual(config.generate_model_s3_path(), expected_model_s3_path)
21+
22+
def test_config_can_call_method(self):
23+
config = Config()
24+
config.generate_model_path()
25+
config.generate_model_s3_path()
26+
27+
def test_config_is_immutable(self):
28+
config = Config()
29+
30+
with self.assertRaises(FrozenInstanceError):
31+
config.version = "v0.0.0"
32+
33+
with self.assertRaises(FrozenInstanceError):
34+
config.train.epoch = 10000
35+
36+
with self.assertRaises(FrozenInstanceError):
37+
config.test.batch_siz = 64
38+
39+
40+
if __name__ == '__main__':
41+
main()

0 commit comments

Comments
 (0)
0