remove numpy, too aggressive by geohot · Pull Request #3075 · tinygrad/tinygrad · GitHub

remove numpy, too aggressive #3075

Closed · geohot wants to merge 1 commit
4 changes: 3 additions & 1 deletion docs/abstractions.py
@@ -146,7 +146,9 @@ class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUO
assert isinstance(result.lazydata.base.realized, Buffer)
assert result.lazydata.base.realized.device == "CLANG"
# getting ahead of ourselves, but we can move the Buffer to CPU
assert result.lazydata.base.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5"
out = memoryview(bytearray(4))
result.lazydata.base.realized.copyout(out)
assert out[0] == 5, "when put in numpy with toCPU, it's 5"

# %%
# == Union[Interpreted, Compiled] (in tinygrad/device.py, code 6/10) ==
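
The change above shows the readback pattern used throughout this PR: instead of Buffer.toCPU() handing back a numpy array, the caller allocates a memoryview of the right byte size and Buffer.copyout fills it. A minimal stdlib-only sketch of the same idea, assuming a little-endian int32 and using struct.pack as a stand-in for the device copy:

import struct
out = memoryview(bytearray(4))            # room for one int32
out[:] = struct.pack("<i", 5)             # stand-in for result.lazydata.base.realized.copyout(out)
assert struct.unpack("<i", out)[0] == 5   # interpret the raw bytes without numpy
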
20 changes: 7 additions & 13 deletions tinygrad/device.py
@@ -1,9 +1,8 @@
from __future__ import annotations
import numpy as np
from collections import defaultdict
from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
import importlib, inspect, functools, pathlib, time, re, ctypes
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.dtype import DType, ImageDType
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters
@@ -85,16 +84,10 @@ def copyin(self, mv:memoryview):
mv = flat_mv(mv)
assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
self.allocator.copyin(self._buf, mv)
return self
@staticmethod
def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
def toCPU(self) -> np.ndarray:
# zero copy with as_buffer
if hasattr(self.allocator, 'as_buffer'):
return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore
ret = np.empty(self.size, self.dtype.np)
if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
return ret
def copyout(self, mv:memoryview):
mv = flat_mv(mv)
assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
self.allocator.copyout(mv, self._buf)

def _internal_buffer_copy(dest:Buffer, src:Buffer):
if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator): # noqa: E721
@@ -119,7 +112,8 @@ def _internal_buffer_copy(dest:Buffer, src:Buffer):
dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
else:
# slow path, allocates a CPU buffer
dest.copyin(src.toCPU().data)
src.copyout((mv := memoryview(bytearray(src.size*src.dtype.itemsize))))
dest.copyin(mv)

class _BufferCopy(JITRunner):
# TODO: make wait work
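
With toCPU/fromCPU gone, Buffer exposes a symmetric copyin/copyout pair over raw memoryviews, and the slow-path device-to-device copy stages through a host allocation. A rough sketch of that staging flow, with plain bytearrays standing in for the two device buffers and a made-up size and itemsize:

size, itemsize = 16, 4
src_buf  = bytearray(range(size * itemsize))   # pretend device buffer A (64 bytes)
dest_buf = bytearray(size * itemsize)          # pretend device buffer B, still zeroed
mv = memoryview(bytearray(size * itemsize))    # host staging buffer, as in the slow path
mv[:] = src_buf                                # src.copyout(mv) in the real code
dest_buf[:] = mv                               # dest.copyin(mv) in the real code
assert dest_buf == src_buf
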
41 changes: 19 additions & 22 deletions tinygrad/dtype.py
@@ -1,13 +1,12 @@
from typing import NamedTuple, Final, Optional, ClassVar, Set, Tuple, Dict
import numpy as np # TODO: remove numpy
import functools

# TODO: migrate this from NamedTuple -> dataclass
class DType(NamedTuple):
priority: int # this determines when things get upcasted
itemsize: int
name: str
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
fmt: Optional[str]
sz: int = 1
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
def vec(self, sz:int):
@@ -17,9 +16,9 @@ def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz

# dependent typing?
class ImageDType(DType):
def __new__(cls, priority, itemsize, name, np, shape, base):
return super().__new__(cls, priority, itemsize, name, np)
def __init__(self, priority, itemsize, name, np, shape, base):
def __new__(cls, priority, itemsize, name, fmt, shape, base):
return super().__new__(cls, priority, itemsize, name, fmt)
def __init__(self, priority, itemsize, name, fmt, shape, base):
self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
self.base: DType = base
super().__init__()
@@ -32,7 +31,7 @@ def __eq__(self, x): return super().__eq__(x) and self.shape == x.shape
def __ne__(self, x): return super().__ne__(x) or self.shape != x.shape

class PtrDType(DType):
def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.fmt, dt.sz)
def __repr__(self): return f"ptr.{super().__repr__()}"

class dtypes:
@@ -42,26 +41,24 @@ def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfl
def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
@staticmethod
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod # NOTE: isinstance(True, int) is True in python
def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
int8: Final[DType] = DType(1, 1, "char", np.int8)
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
int16: Final[DType] = DType(3, 2, "short", np.int16)
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
int32: Final[DType] = DType(5, 4, "int", np.int32)
uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
int64: Final[DType] = DType(7, 8, "long", np.int64)
uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
float16: Final[DType] = DType(9, 2, "half", np.float16)
bool: Final[DType] = DType(0, 1, "bool", '?')
int8: Final[DType] = DType(1, 1, "char", 'b')
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B')
int16: Final[DType] = DType(3, 2, "short", 'h')
uint16: Final[DType] = DType(4, 2, "unsigned short", 'H')
int32: Final[DType] = DType(5, 4, "int", 'i')
uint32: Final[DType] = DType(6, 4, "unsigned int", 'I')
int64: Final[DType] = DType(7, 8, "long", 'l')
uint64: Final[DType] = DType(8, 8, "unsigned long", 'L')
float16: Final[DType] = DType(9, 2, "half", 'e')
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
bfloat16: Final[DType] = DType(10, 2, "__bf16", None)
float32: Final[DType] = DType(11, 4, "float", np.float32)
float64: Final[DType] = DType(12, 8, "double", np.float64)
float32: Final[DType] = DType(11, 4, "float", 'f')
float64: Final[DType] = DType(12, 8, "double", 'd')

# dtype aliases
half = float16; float = float32; double = float64 # noqa: E702
@@ -70,9 +67,9 @@ def fields() -> Dict[str, DType]: return DTYPES_DICT

# NOTE: these are image dtypes
@staticmethod
def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp, dtypes.float32)
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', shp, dtypes.float32)
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', shp, dtypes.float32)

default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32
7 changes: 0 additions & 7 deletions tinygrad/lazy.py
@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, math
import numpy as np
from collections import defaultdict
from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict, cast
from tinygrad.dtype import dtypes, DType, ImageDType
@@ -79,12 +78,6 @@ def is_unrealized_contiguous_const(self): return self.base == self and not self.

def schedule(self, seen=None): return create_schedule([self], seen)

@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
ret = LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), dtypes.from_np(x.dtype), op=LoadOps.EMPTY)
ret.realized = Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten())
return ret

def copy_to_device(self, device:str) -> LazyBuffer:
# no COPY
if self.device == device: return self
6 changes: 6 additions & 0 deletions tinygrad/runtime/ops_cpu.py
@@ -1,9 +1,15 @@
import numpy as np
from typing import Callable, Dict, Tuple
from tinygrad.helpers import flat_mv
from tinygrad.dtype import dtypes
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
from tinygrad.device import Interpreted, Allocator

type_map = {np.bool_: dtypes.bool,
np.int8: dtypes.int8, np.uint8: dtypes.uint8, np.int16: dtypes.int16, np.int32: dtypes.int32, np.int64: dtypes.int64,
np.float16: dtypes.float16, np.float32: dtypes.float32, np.float64: dtypes.float64}
inverse_type_map = {v: k for k,v in type_map.items()}

def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)
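
Note that type_map is keyed on numpy scalar type classes (np.float32, np.int32, ...) rather than np.dtype instances, which is why the new _fromcpu in tensor.py indexes it with x.dtype.type. A quick sketch of that distinction, assuming only numpy:

import numpy as np
x = np.arange(4, dtype=np.float32)
print(type(x.dtype), x.dtype.type)   # <class 'numpy.dtype[...]'> <class 'numpy.float32'>
print(x.dtype.type is np.float32)    # True: this class is the key used for type_map lookups
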
32 changes: 19 additions & 13 deletions tinygrad/tensor.py
@@ -6,6 +6,7 @@
from functools import partialmethod, reduce
import numpy as np

from tinygrad.runtime.ops_cpu import type_map, inverse_type_map
from tinygrad.dtype import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
from tinygrad.lazy import LazyBuffer, create_schedule
@@ -41,6 +42,11 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)

def _fromcpu(x: np.ndarray) -> LazyBuffer:
ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, type_map[x.dtype.type], "CPU")
ret.realized = Buffer("CPU", prod(x.shape), type_map[x.dtype.type], x.flatten())
return ret

Scalar = Union[float, int, bool]

class Tensor:
@@ -68,17 +74,17 @@ def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray,
self._ctx: Optional[Function] = None
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
elif isinstance(data, list):
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
elif d and all_int(d): dtype = dtype or dtypes.default_int
else: dtype = dtype or dtypes.default_float
# NOTE: cast at the end for the dtypes that do not have a numpy dtype
data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
data = _fromcpu(np.array(data, inverse_type_map[dtype])).cast(dtype)
elif isinstance(data, np.ndarray):
if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
else: data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or type_map[data.dtype], device, data.item())
else: data = _fromcpu(data.astype(inverse_type_map[dtype]) if dtype is not None and dtype in inverse_type_map else data)

# data is a LazyBuffer, but it might be on the wrong device
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
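
For Python lists, the dtype is picked by inspecting the flattened values before the numpy round-trip: all bools give dtypes.bool, all ints the default int, anything else the default float, and the numpy array is then built with the format from inverse_type_map. A stdlib-only sketch of that selection rule, with a tiny hypothetical helper in place of tinygrad.helpers.fully_flatten:

def _flatten(x):
    # hypothetical stand-in for fully_flatten
    return [v for item in x for v in (_flatten(item) if isinstance(item, (list, tuple)) else [item])]

def pick_dtype(data):
    d = _flatten(data)
    if d and all(isinstance(s, bool) for s in d): return "bool"          # bool checked first: isinstance(True, int) is True
    if d and all(isinstance(s, int) for s in d):  return "default_int"
    return "default_float"

print(pick_dtype([[True, False]]), pick_dtype([1, 2, 3]), pick_dtype([1.5, 2]))  # bool default_int default_float
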
@@ -134,18 +140,18 @@ def assign(self, x) -> Tensor:
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)

# TODO: these are good places to start removing numpy
def data(self) -> memoryview:
assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
assert self.dtype.fmt is not None, f"no memoryview dtype for {self.dtype}"
mv = memoryview(cast(int, self.numel()) * bytearray(self.dtype.itemsize))
cast(Buffer, self.contiguous().realize().lazydata.base.realized).copyout(mv)
return mv.cast(self.dtype.fmt, self.shape)
def item(self) -> Scalar:
assert self.numel() == 1, "must have one element for item"
return cast(Buffer, self.contiguous().realize().lazydata.base.realized).toCPU().item()
def data(self) -> memoryview: return self.numpy().data
return self.data()[0]

# TODO: this should import numpy and use .data() to construct the array
def numpy(self) -> np.ndarray:
assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
t = self if isinstance(self.device, str) else self.to("CPU")
return t.cast(self.dtype.scalar()).contiguous().realize().lazydata.base.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
def numpy(self) -> np.ndarray: return np.asarray(self.data())

def to(self, device:Optional[str]) -> Tensor:
if device is None or device == self.device: return self
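
Tensor.data() now allocates numel*itemsize bytes, copies the realized Buffer out, and casts the memoryview to the dtype's format code and the tensor's shape; numpy() becomes np.asarray over that view. A sketch of the cast step, with a hypothetical 2x3 float32 tensor and numpy standing in for the device copy:

import numpy as np
numel, itemsize, fmt, shape = 6, 4, "f", (2, 3)
mv = memoryview(numel * bytearray(itemsize))               # same allocation trick as in data()
mv[:] = np.arange(6, dtype=np.float32).tobytes()           # stand-in for Buffer.copyout(mv)
shaped = mv.cast(fmt, shape)                               # 2x3 float32 view, no numpy needed
print(shaped[1, 2])                                        # -> 5.0
print(np.asarray(shaped).shape, np.asarray(shaped).dtype)  # -> (2, 3) float32
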
@@ -950,5 +956,5 @@ def custom_random(out:Buffer):
Tensor._seed += 1
if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
rng = np.random.default_rng(Tensor._seed)
rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=inverse_type_map[out.dtype], copy=False)
out.copyin(rng_np_buffer.data)