EXA use openml fetcher in plot_gpr_co2.py example (#12004) · scikit-learn/scikit-learn@2242f4c · GitHub
[go: up one dir, main page]

Skip to content

Commit 2242f4c

Browse files
maxcopelandrth
authored and committed
EXA use openml fetcher in plot_gpr_co2.py example (#12004)
1 parent e5333f5 commit 2242f4c

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

examples/gaussian_process/plot_gpr_co2.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
import numpy as np
6767

6868
from matplotlib import pyplot as plt
69-
69+
from sklearn.datasets import fetch_openml
7070
from sklearn.gaussian_process import GaussianProcessRegressor
7171
from sklearn.gaussian_process.kernels \
7272
import RBF, WhiteKernel, RationalQuadratic, ExpSineSquared
@@ -79,37 +79,33 @@
7979
print(__doc__)
8080

8181

82-
def load_mauna_loa_atmospheric_c02():
83-
url = ('http://cdiac.ess-dive.lbl.gov/'
84-
'ftp/trends/co2/sio-keel-flask/maunaloa_c.dat')
82+
def load_mauna_loa_atmospheric_co2():
83+
ml_data = fetch_openml(data_id=41187)
8584
months = []
8685
ppmv_sums = []
8786
counts = []
88-
for line in urlopen(url):
89-
line = line.decode('utf8')
90-
if not line.startswith('MLO'):
91-
# ignore headers
92-
continue
93-
station, date, weight, flag, ppmv = line.split()
94-
y = date[:2]
95-
m = date[2:4]
96-
month_float = (int(('20' if y < '20' else '19') + y) +
97-
(int(m) - 1) / 12)
98-
if not months or month_float != months[-1]:
99-
months.append(month_float)
100-
ppmv_sums.append(float(ppmv))
87+
88+
y = ml_data.data[:, 0]
89+
m = ml_data.data[:, 1]
90+
month_float = y + (m - 1) / 12
91+
ppmvs = ml_data.target
92+
93+
for month, ppmv in zip(month_float, ppmvs):
94+
if not months or month != months[-1]:
95+
months.append(month)
96+
ppmv_sums.append(ppmv)
10197
counts.append(1)
10298
else:
10399
# aggregate monthly sum to produce average
104-
ppmv_sums[-1] += float(ppmv)
100+
ppmv_sums[-1] += ppmv
105101
counts[-1] += 1
106102

107103
months = np.asarray(months).reshape(-1, 1)
108104
avg_ppmvs = np.asarray(ppmv_sums) / counts
109105
return months, avg_ppmvs
110106

111107

112-
X, y = load_mauna_loa_atmospheric_c02()
108+
X, y = load_mauna_loa_atmospheric_co2()
113109

114110
# Kernel with parameters given in GPML book
115111
k1 = 66.0**2 * RBF(length_scale=67.0) # long term smooth rising trend

0 commit comments

Comments
 (0)
0