8000 feat: data converter manager by bingogome · Pull Request #73 · Tavish9/any4lerobot · GitHub
[go: up one dir, main page]

Skip to content
Prev Previous commit
Next Next commit
[DEV] update to filter new 3.0 fields
  • Loading branch information
bingogome committed Oct 18, 2025
commit 98f4a531218ec793bc9fce6af2affa4f7b8c92a4
55 changes: 54 additions & 1 deletion ds_version_convert/convert_dataset_v21_to_v20.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.utils import cast_stats_to_numpy, load_info, write_info, write_stats

try: # pragma: no cover - numpy is expected but we guard for minimal setups
import numpy as np
except ImportError: # pragma: no cover
np = None

try: # pragma: no cover - compatibility with older package structure
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V20, V21
except ImportError: # pragma: no cover
Expand Down Expand Up @@ -101,6 +106,53 @@ def _remove_episode_stats_files(root: Path) -> None:
shutil.rmtree(path)


LEGACY_STATS_KEYS = ("mean", "std", "min", "max", "q01", "q99")


def _to_python_scalar(value: Any) -> Any:
if np is not None:
if isinstance(value, np.generic):
return float(value)
if isinstance(value, np.ndarray):
return _to_python_scalar(value.tolist())

if isinstance(value, (list, tuple)):
return [_to_python_scalar(v) for v in value]

if isinstance(value, dict):
return {k: _to_python_scalar(v) for k, v in value.items()}

if isinstance(value, (int, float)):
return float(value)

return value


def _format_stat_value(value: Any) -> Any:
converted = _to_python_scalar(value)
if isinstance(converted, list):
return converted
return [converted]


def _format_stats_for_v20(aggregated_stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
legacy_stats: dict[str, dict[str, Any]] = {}
for feature_name, stats in aggregated_stats.items():
if not isinstance(stats, dict):
continue

formatted_stats = {}
for key in LEGACY_STATS_KEYS:
if key not in stats:
continue
formatted_stats[key] = _format_stat_value(stats[key])

if formatted_stats:
legacy_stats[feature_name] = formatted_stats

return legacy_stats


def _prepare_dataset_root(repo_id: str, output_dir: str) -> Path:
output_path = Path(output_dir).expanduser().resolve()
output_path.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -146,7 +198,8 @@ def convert_dataset(
raise RuntimeError("No per-episode stats found; cannot reconstruct legacy stats.json")

aggregated_stats = aggregate_stats(list(episodes_stats.values()))
write_stats(aggregated_stats, dataset_root)
legacy_formatted_stats = _format_stats_for_v20(aggregated_stats)
write_stats(legacy_formatted_stats, dataset_root)

info["codebase_version"] = V20
write_info(info, dataset_root)
Expand Down
0