8000 slice_step_setter_diff_shape · SciSharp/TensorFlow.NET@30f8f67 · GitHub
[go: up one dir, main page]

Skip to content

Commit 30f8f67

Browse files
committed
slice_step_setter_diff_shape
1 parent 8ab1fe9 commit 30f8f67

File tree

4 files changed

+86
-15
lines changed

4 files changed

+86
-15
lines changed

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public NDArray this[NDArray mask]
5757
}
5858
}
5959

60-
60+
[AutoNumPy]
6161
unsafe NDArray GetData(Slice[] slices)
6262
{
6363
if (shape.IsScalar)
@@ -170,9 +170,9 @@ unsafe NDArray GetData(int[] indices, int axis = 0)
170170
}
171171

172172
void SetData(IEnumerable<Slice> slices, NDArray array)
173-
=> SetData(array, data, slices.ToArray(), new int[shape.ndim].ToArray(), -1);
173+
=> SetData(array, slices.ToArray(), new int[shape.ndim].ToArray(), -1);
174174

175-
unsafe void SetData(NDArray src, IntPtr dst, Slice[] slices, int[] indices, int currentNDim)
175+
unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim)
176176
{
177177
if (dtype != src.dtype)
178178
src = src.astype(dtype);
@@ -181,20 +181,23 @@ unsafe void SetData(NDArray src, IntPtr dst, Slice[] slices, int[] indices, int
181181
if (!slices.Any())
182182
return;
183183

184+
if (shape.Equals(src.shape))
185+
{
186+
System.Buffer.MemoryCopy(src.data.ToPointer(), data.ToPointer(), src.bytesize, src.bytesize);
187+
return;
188+
}
189+
184190
// first iteration
185191
if(currentNDim == -1)
186192
{
187193
slices = SliceHelper.AlignWithShape(shape, slices);
188-
if (!shape.Equals(src.shape))
189-
{
190-
var newShape = ShapeHelper.AlignWithShape(shape, src.shape);
191-
src = src.reshape(newShape);
192-
}
193194
}
194195

195196
// last dimension
196197
if (currentNDim == ndim - 1)
197198
{
199+
var offset = (int)ShapeHelper.GetOffset(shape, indices);
200+
var dst = data + offset * (int)dtypesize;
198201
System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize);
199202
return;
200203
}
@@ -206,13 +209,56 @@ unsafe void SetData(NDArray src, IntPtr dst, Slice[] slices, int[] indices, int
206209
var stop = slice.Stop ?? (int)dims[currentNDim];
207210
var step = slice.Step;
208211

209-
for (var i = start; i < stop; i += step)
212+
if(step != 1)
210213
{
211-
indices[currentNDim] = i;
212-
var offset = (int)ShapeHelper.GetOffset(shape, indices);
213-
dst = data + offset * (int)dtypesize;
214-
var srcIndex = (i - start) / step;
215-
SetData(src[srcIndex], dst, slices, indices, currentNDim);
214+
for (var i = start; i < stop; i += step)
215+
{
216+
if (i >= dims[currentNDim])
217+
throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}");
218+
219+
indices[currentNDim] = i;
220+
if (currentNDim < ndim - src.ndim)
221+
{
222+
SetData(src, slices, indices, currentNDim);
223+
}
224+
else
225+
{
226+
var srcIndex = (i - start) / step;
227+
SetData(src[srcIndex], slices, indices, currentNDim);
228+
}
229+
}
230+
}
231+
else
232+
{
233+
for (var i = start; i < stop; i++)
234+
{
235+
if (i >= dims[currentNDim])
236+
throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}");
237+
238+
indices[currentNDim] = i;
239+
if (currentNDim < ndim - src.ndim)
240+
{
241+
SetData(src, slices, indices, currentNDim);
242+
}
243+
// last dimension
244+
else if(currentNDim == ndim - 1)
245+
{
246+
SetData(src, slices, indices, currentNDim);
247+
break;
248+
}
249+
else if(SliceHelper.IsContinuousBlock(slices, currentNDim))
250+
{
251+
var offset = (int)ShapeHelper.GetOffset(shape, indices);
252+
var dst = data + offset * (int)dtypesize;
253+
System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize);
254+
return;
255+
}
256+
else
257+
{
258+
var srcIndex = i - start;
259+
SetData(src[srcIndex], slices, indices, currentNDim);
260+
}
261+
}
216262
}
217263

218264
// reset indices

src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ public partial class NDArray
1616
public static NDArray operator *(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mul", lhs, rhs));
1717
[AutoNumPy]
1818
public static NDArray operator /(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("div", lhs, rhs));
19+
[AutoNumPy]
20+
public static NDArray operator %(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mod", lhs, rhs));
1921
[AutoNumPy]
2022
public static NDArray operator >(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.greater(lhs, rhs));
2123
[AutoNumPy]

src/TensorFlowNET.Core/NumPy/SliceHelper.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,16 @@ public static bool AreAllIndex(Slice[] slices, out int[] indices)
5555
}
5656
return true;
5757
}
58+
59+
public static bool IsContinuousBlock(Slice[] slices, int ndim)
60+
{
61+
for (int i = ndim + 1; i < slices.Length; i++)
62+
{
63+
if (slices[i].Equals(Slice.All))
64+
continue;
65+
return false;
66+
}
67+
return true;
68+
}
5869
}
5970
}

test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs

Lines changed: 13 additions & 1 deletion
< DA5E /tr>
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ public void iterating()
118118
}
119119

120120
[TestMethod]
121-
public void slice_step()
121+
public void slice_step_setter()
122122
{
123123
var array = np.arange(32).reshape((4, 8));
124124
var s1 = array[Slice.All, new Slice(2, 5, 2)] + 1;
@@ -131,5 +131,17 @@ public void slice_step()
131131
Assert.AreEqual(array[2], new[] { 16, 17, 19, 19, 21, 21, 22, 23 });
132132
Assert.AreEqual(array[3], new[] { 24, 25, 27, 27, 29, 29, 30, 31 });
133133
}
134+
135+
[TestMethod]
136+
public void slice_step_setter_diff_shape()
137+
{
138+
var array = np.arange(32).reshape((4, 8));
139+
var s1 = np.array(new[] { 100, 200 });
140+
array[Slice.All, new Slice(2, 5, 2)] = s1;
141+
Assert.AreEqual(array[0], new[] { 0, 1, 100, 3, 200, 5, 6, 7 });
142+
Assert.AreEqual(array[1], new[] { 8, 9, 100, 11, 200, 13, 14, 15 });
143+
Assert.AreEqual(array[2], new[] { 16, 17, 100, 19, 200, 21, 22, 23 });
144+
Assert.AreEqual(array[3], new[] { 24, 25, 100, 27, 200, 29, 30, 31 });
145+
}
134146
}
135147
}

0 commit comments

Comments
 (0)
0