[RFC] RPC Based Distributed Model Parallel · Issue #23110 · pytorch/pytorch · GitHub

[RFC] RPC Based Distributed Model Parallel #23110

Open
mrshenli opened this issue Jul 19, 2019 · 21 comments
Labels
feature A request for a proper, new feature. module: rpc Related to RPC, distributed autograd, RRef, and distributed optimizer triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@mrshenli
Contributor
mrshenli commented Jul 19, 2019

with @pritamdamania87 @zhaojuanmao @aazzolini @gqchen @pietern @satgera @ezyang @zdevito @suo @manojkris @gchanan @soumith @dzhulgakov @yifuwang @bddppq @joxu-cn @dwarakrajagopal @jspisak

PyTorch currently provides simple APIs for single machine data parallel, distributed data parallel, and single machine model parallel. However, when it comes to distributed model parallel, applications have to build their own scaffold to stitch together local autograd graphs into one global graph. This proposal aims to fill in that gap by providing an RPC-Based distributed model parallel API. In short, applications may run RPC to execute code remotely in the forward pass, and autograd will automatically travel across RPC boundaries in the backward pass.

API

Core Concepts

RRef[T] - (abbreviation ref) A reference to a value of some type T (e.g. Tensor) on a remote worker. This handle keeps the referenced remote value alive on the owner, but there is no implication that the value will be transferred to the local worker in the future. It is valid to have a reference to a local value as well, and values of type T can be implicitly converted to RRef[T]. This implicit conversion will be critical later to allow the expression of different types of RPC. Think of it like the implicit conversion from std::string to const std::string &. See the System Design section for more details about RRef.

ref.owner() # returns the worker this value lives on

v = ref.local_value() # if ref.owner() is the local worker, this returns
                      # the underlying value; otherwise it raises an error.
# you can create a ref to a local tensor
t = torch.rand(3, 4)         
ref2 = torch.RRef(t)

# in TorchScript, T can be automatically converted to RRef[T]
ref3 : RRef[Tensor] = t

Future[T] - (abbreviation fut) a guarantee that at some future point in time the value of type T will be available locally. The action to create T locally is assumed to be scheduled and in-progress. Future is already supported in TorchScript and we are extending this to remote calls.

v = fut.wait() # block the current thread until v is ready

# local cpu task creation returns a future to the computed tensors
fut = torch.fork(lambda x, y: x + y, torch.rand(3, 4), torch.rand(3, 4))

Core Functions

# synchronous
result : T = torch.rpc(on : Worker, remote_callable : Callable, *args)
# asynchronous
result : Future[T] = torch.async_rpc(on : Worker, remote_callable : Callable, *args)
# remote reference
result : RRef[T] = torch.remote(on : Worker, remote_callable : Callable, *args)

Each function above invokes remote_callable on a remote worker. Value types in the args list are copied by value to the remote worker. RRef[T] types in the args list are copied by reference to the remote worker (again see the analogy between std::string and const std::string&).

The synchronous variant copies the result value back, blocking the calling thread until the response arrives. The asynchronous variant returns immediately with a future; the remote worker knows that the caller expects to receive the value, so it will send the result back without further prompting as soon as it is ready.

The remote reference variant returns immediately with an RRef of the return value. The remote knows that the caller does not expect to receive the result value.

Below shows how these functions are used:

# make some local tensors
a : Tensor = torch.rand(3, 4)
b : Tensor = torch.rand(3, 4)

# define a remote function, visible to all machines.
# type annotations define expected input/output types.
def remote_function(a : Tensor, b : RRef[Tensor]) -> Tensor:
   # 'b' in the type signature is a remote reference, so we must copy it here
   # to use it locally.
   # to_here() is defined later in the syntax sugar section, it synchronously
   # copies the tensor to this worker.
   b_l  : Tensor = b.to_here()
   return a + b_l

# run remote_function on a different worker.
# a is copied by value since it is a Tensor.
# b is copied by reference to the remote machine due to the RRef[Tensor]
# type annotation in the signature, which causes an implicit conversion to a
# reference type.

# torch.remote always creates an RRef of the result type. 
# It does not wait for the remote's response. 
# There is no implied copy of the tensor data yet.
c : RRef[Tensor] = torch.remote("worker1", remote_function, a, b)

# we can explicitly request the data to be copied back here:
c_l : Tensor = c.to_here()

# another example:
def remote_function2(a : Tensor, b : Tensor) -> Tensor:
   return a + b

# Here we call torch.rpc which returns the value directly without
# creating a remote reference.
# we synchronously wait for remote_function2 to return.
c : Tensor = torch.rpc("worker2", remote_function2, a, b)

# When the RPC call is returning a non-reference type, we need to wait for 
# a response from the remote host. To avoid synchronously waiting, use the
# async flag to get a future instead.
c_f : Future[Tensor] = torch.async_rpc("worker2", remote_function2, a, b)
# even before calling wait, the remote knows that the data should be sent back
# to the caller as soon as it is ready.

# force the local thread to wait for the remote's response
c = c_f.wait()


# if you omit type annotations in the remote function, the assumption is that 
# arguments are passed without any implicit conversions
def remote_function3(a, b):
   # no annotations mean that a, b will be Tensor since there is no conversion
   return a + b

c: Tensor = torch.rpc("worker2", remote_function3, a, b)

RRef Forks

Implicit Conversions for RRef Arguments

We allow implicit conversion between T and RRef[T] for arguments of RPC functions. Both the actual and formal parameter can either be a T or an RRef[T], leading to four cases that might occur:

T → T (passing a T to an RPC that accepts a T): the value T is copied by value and sent over the wire as part of the message invoking the RPC.

T → RRef[T] (passing a T to an rpc that accepts RRef[T]): The caller constructs a remote reference to the argument, and sends the reference over the wire to the callee. The data is not sent. The callee can then use the reference as a handle to either request the data later or to make further remote calls.

RRef[T] → T (passing an RRef[T] to an rpc that accepts T): The callee expects to get an actual value, so the callee needs to turn the reference into a value. The network behavior depends on where the RRef[T] lives.

  • If the RRef[T] lives on the caller, then the implementation looks up the actual value of T locally and passes it by value over the wire, similar to the T → T case.
  • If the RRef[T] lives on the callee, then the implementation just sends the reference and the callee does the lookup locally.
  • If the RRef[T] lives on some third machine, then the caller sends two messages: one to the third machine telling it to send the data in the remote reference directly to the callee, and one to the callee telling it to start the RPC and expect this input to be coming from the third machine. This effectively forwards the value of the RRef[T] to the callee without the caller having to load it or the callee having to request it later.

Examples:

def remote_function1() -> Tensor:
    return torch.ones(2)
    
def remote_function2(a : Tensor) -> Tensor:
    b = a * 2
    return b

aref : RRef[Tensor] = remote("worker1", remote_function1)

# this local worker will make two RPC calls: one to tell worker1 to send the 
# tensor to worker2, and another one to tell worker2 to expect this Tensor input
# from worker1. remote_function2 will run on worker2 only after it has received 
# the tensor from worker1.  
bref : RRef[Tensor] = remote("worker2", remote_function2, aref) 

RRef[T] → RRef[T] (passing an RRef[T] to an RPC that accepts RRef[T]): The callee expects an RRef[T], but we must make sure we correctly keep track of references to the value on a remote. So the actual behavior depends on where the RRef[T] lives.

  • If RRef[T] lives on the caller, then we simply pass it to the remote and record that this remote now has a live reference to the value.
  • If the RRef[T] lives on the callee, then we pass it to the remote, and it becomes a local reference on the remote.
  • If RRef[T] lives on some third machine, then we must forward the reference. To do this the caller sends two messages: one to the third machine telling it to create a remote reference and send it to the callee, and one to the callee telling it where to expect the remote reference from. The callee code is not invoked until the remote reference is transferred, to ensure sane reference counting.

Examples:

def remote_function1() -> Tensor:
    return torch.ones(2)
    
def remote_function2(a : RRef[Tensor]) -> Tensor:
    delta = 10
    return a.to_here() + delta

aref : RRef[Tensor] = remote("worker1", remote_function1)

# this local worker will make two RPC calls: one to tell worker1 to create a 
# remote reference and send it to worker2, and another one to tell worker2 to 
# expect this remote reference input from worker1. remote_function2 code will 
# not run on worker2 until it receives the remote reference from worker1 to 
# ensure proper reference counting.   
bref : RRef[Tensor] = remote("worker2", remote_function2, aref) 

When an RRef[T] goes dead on machine A, a message is sent to the owner of T telling it that the reference from machine A is dead.

Explicit RRef type for return values

The above implicit RRef argument conversion does not apply to return values. If remote_function returns RRef[T], calling it remotely using torch.remote returns RRef[RRef[T]] instead of RRef[T]. This is because the return-value RRef of torch.remote is first created on the caller, which does not yet know the owner of the real data T. T could be stored on the callee of torch.remote, but it could also live on a different worker, as the callee may itself make another remote call within remote_function and return an RRef[T] owned by that other worker. Moreover, the caller is allowed to share the returned RRef with other workers immediately after torch.remote returns. Since the caller does not yet know the real owner of T at that point, sharing the RRef would break the reference counting algorithm. (A sketch of unwrapping the nested RRef follows the example below.)

Examples:

def remote_function3() -> RRef[Tensor]:
    return torch.remote("Worker2", torch.ones, 2, 2)
 
cref : RRef[RRef[Tensor]] = remote("worker1", remote_function3) 
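
To make the nested type concrete: under the proposed API, the caller can unwrap the result in two hops, first fetching the inner RRef (a fork of it) and then fetching the tensor from its owner. This is a sketch against the proposed RRef.to_here() semantics above, not a confirmed interface.

# sketch only: uses the proposed torch.remote / RRef.to_here() API from this RFC
cref : RRef[RRef[Tensor]] = torch.remote("worker1", remote_function3)

inner : RRef[Tensor] = cref.to_here()   # fetch the inner RRef (a fork) to this worker
t : Tensor = inner.to_here()            # fetch the actual tensor from its owner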

Initialization API

Users may choose the communication backend for RPC, and they are responsible for setting up the backend properly before calling the init_rpc method.

# backend: specifies the underlying communication implementation
# init_method: contains the information to initialize/connect a name store to 
#              resolve names
# name: is a unique identifier for the current worker
torch.distributed.init_rpc(backend="pg", init_method="file:///...", name="worker1")

The init_rpc method will create an RpcAgent under the hood and make the current worker ready to send and receive RPC calls. If you call init_rpc with the ProcessGroup (pg) backend, it acts as a global barrier: all node names are collectively synchronized before continuing. This is not the case with a peer-to-peer backend (e.g. TensorPipe), where calling init_rpc simply registers the node name in the specified store and starts serving.

Applications don’t need to explicitly register functions for remote execution, but we do assume the same functions are defined on both the caller and the callee. This is often true, as all workers can import the same set of libraries or even share the same Python script.
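
To make the initialization flow concrete, below is a minimal sketch of a shared worker script under this proposal. The worker names, the file-store path, and the add function are made up for illustration, and torch.distributed.init_rpc / torch.rpc are the APIs proposed in this RFC rather than an existing interface.

# worker.py -- run the same script on every node, varying only the worker name
import sys
import torch
import torch.distributed as dist

def add(a, b):
    # defined in the same script on every worker, so no explicit registration is needed
    return a + b

if __name__ == "__main__":
    name = sys.argv[1]  # e.g. "worker0" or "worker1"
    # with the "pg" backend this call acts as a global barrier across all workers
    dist.init_rpc(backend="pg", init_method="file:///tmp/rpc_init", name=name)

    if name == "worker0":
        c = torch.rpc("worker1", add, torch.ones(2, 2), torch.ones(2, 2))
        print(c)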

Syntax Sugar

Other operations can now be implemented as syntax sugar on top of the core functions.

Retrieving Value From RRef

# helper private RPC functions
def _identity(v : Tensor) -> Tensor:
    # copy the tensor by value to this remote,
    return v

def _to_here(v : RRef[T]) -> T:
    # take a reference, send it to the device that owns it
    # and have that device return the actual tensor by value
    return v.local_value()

class RRef[T]:
    ...
    # copy a remote tensor to the local worker, sync version
    def to_here(self) -> T:
        return torch.rpc(self.owner(), _to_here, self)

Builtin Operators

# proxy methods for all builtin functions exist on references for 
# existing TorchScript types like Tensors. They always follow a fixed pattern:
def _mm(a : RRef[Tensor], b : RRef[Tensor]) -> Tensor:
    return a.local_value().mm(b.local_value())
     
class RRef[Tensor]:
    def mm(self : RRef[Tensor], other : RRef[Tensor]) -> RRef[Tensor]:
        on = same_worker(self.owner(), other.owner())
        return torch.remote(on, _mm, self, other)


c : Tensor = a.mm(b).to_here()

Callable and RRef

If RRef[T] holds a callable object T, the application may call the RRef directly, which will be translated into a torch.remote call to the owner of the callable.

# if T is callable for RRef[T], rref(x) will be translated to calling T(x) 
# on the owner of the RRef
def _call_rref(v : RRef[T], *args):
    return v.local_value()(*args)

class RRef[T]:
    def __call__(self, *args):
        return torch.remote(self.owner(), _call_rref, self, *args)

net = torch.remote("Worker1", Net)
net(inputs)

Optimizer and RRef

As models might have remote sub-modules (i.e., RRef[nn.Module]), we should provide an optimizer sugar to handle them. The optimizer sugar (torch.optim.remote, written as dist.optimizer in the examples below) takes a local optimizer constructor, a distributed model parallel model, and an argument list for the local optimizer constructor. It recursively creates a local optimizer on every remote sub-module owner, and exposes the same step API as a local optimizer, which recursively calls every local optimizer. (A rough sketch of such an optimizer follows the example below.)

class Net1(nn.Module):
    ...

class Net2(nn.Module):
    ...

class DMP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = dist.remote("worker1", Net1)
        self.net2 = dist.remote("worker2", Net2)
        
dmp = dist.remote("worker0", DMP)
# dist.optimizer creates an optimizer on all RRef owners
optimizer = dist.optimizer(torch.optim.SGD, dmp, lr=0.1)

with dist.autograd.context():
  loss = dmp(inputs)
  dist.autograd.backward(loss)
  optimizer.step()
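
For reference, the optimizer sugar could be sketched roughly as follows, built on the torch.remote / torch.async_rpc primitives proposed above. For brevity it takes a flat list of RRef sub-modules instead of recursing over the model, and the helper names (_create_local_optim, _local_step) are hypothetical; a real implementation would also read the remotely accumulated gradients out of the distributed autograd context.

def _create_local_optim(optim_cls, module_rref, *args):
    # runs on the owner of module_rref: build a plain local optimizer there
    module = module_rref.local_value()
    return optim_cls(module.parameters(), *args)

def _local_step(optim_rref, autograd_ctx_id):
    # runs on the owner: apply the gradients accumulated for this backward pass
    # (a real implementation would fetch them from the given autograd context)
    optim_rref.local_value().step()

class DistributedOptimizer:
    def __init__(self, optim_cls, module_rrefs, *args):
        # one local optimizer per remote sub-module owner
        self.remote_optims = [
            torch.remote(rref.owner(), _create_local_optim, optim_cls, rref, *args)
            for rref in module_rrefs
        ]

    def step(self, autograd_ctx_id):
        # fan out step() to every owner and wait for all of them to finish
        futs = [
            torch.async_rpc(opt_rref.owner(), _local_step, opt_rref, autograd_ctx_id)
            for opt_rref in self.remote_optims
        ]
        for fut in futs:
            fut.wait()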

Model Parallel Training Examples

Multi-Machine Model Training

# 1. load data
inputs_rref = torch.remote("worker1", load_inputs, path_to_inputs) 
labels_rref = torch.remote("worker2", load_labels, path_to_labels)

# 2. define model
class Net1(nn.Module):
    ...

class Net2(nn.Module):
    ...

class DMP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = torch.remote("worker1", Net1)
        self.net2 = torch.remote("worker2", Net2)
        
    def forward(self, inputs_rref):
        # RRef[T].__call__(args) is sugar that translates into a torch.remote
        # call on the RRef's owner (see the Callable and RRef section above)
        outputs1_rref = self.net1(inputs_rref)
        outputs2_rref = self.net2(outputs1_rref)
        return outputs2_rref
        
# 3. training, run it where you want to call autograd
def train(inputs_rref, labels_rref):
    dmp = DMP()
    # dist.optimizer (i.e. torch.optim.remote) creates an optimizer on every RRef owner
    optimizer = dist.optimizer(torch.optim.SGD, dmp, lr=0.1)
    outputs_rref = dmp(inputs_rref)
    loss = loss_func(outputs_rref.to_here(), labels_rref.to_here())
    autograd_ctx_id = dist.autograd.backward(loss)
    optimizer.step(autograd_ctx_id)
    
dist.rpc(dev2, train, args=(inputs_rref, labels_rref))

Parameter Server Training

class ParameterServer:
    def __init__(self):
        self.params = torch.zeros(100, 100).to(0)
        
    def get_params(self) -> Tensor:
        return self.params
        
    def add_grads(self, grad: Tensor):
        self.params += grad.to(0)
        
def train(ps):
    for _ in range(10):
        params = torch.rpc("ps", ParameterServer.get_params, args=(ps, ))
        # run forward and backward
        torch.rpc("ps", ParameterServer.add_grads, args=(ps, params.grad))
        torch.distributed.barrier(group=TRAINER_GROUP)
        
ps = torch.remote("worker1", ParameterServer)
torch.remote("worker2", train, args=(ps,))
torch.remote("worker3", train, args=(ps,))
            

System Design

Distributed Autograd

Basic Idea

In the first version, dist.autograd.backward does not support RRef arguments, but RRef can still help build the autograd graph. The overall idea is as follows.

  • When calling torch.rpc or RRef.to_here(), send and recv autograd functions will be inserted to connect local autograd graphs on multiple workers into one distributed autograd graph.
  • Every distributed backward pass is assigned a globally unique id (autograd_context_id), and every participating worker will keep a dedicated context for it (see the sketch after this list).
  • When the backward computation reaches a recv function, it packs the gradient and the autograd_context_id in the message, and passes the message to its send counterpart.
  • Upon receiving a message for a send function in the backward pass, it uses the autograd_context_id in the message to identify which backward pass it belongs to, and uses the gradient in the message to continue autograd computation locally.
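
A minimal sketch of the per-worker bookkeeping implied by the bullets above; the class and function names here are illustrative only:

# hypothetical per-worker registry of distributed autograd contexts
import itertools
import threading

_worker_id = 0                       # assumption: assigned at init_rpc time
_local_counter = itertools.count()
_contexts = {}                       # autograd_context_id -> DistAutogradContext
_lock = threading.Lock()

class DistAutogradContext:
    def __init__(self, context_id):
        self.context_id = context_id
        self.send_functions = {}     # send_id -> send autograd function
        self.recv_functions = {}     # recv_id -> recv autograd function
        self.accumulated_grads = {}  # parameter -> gradient tensor

def new_context_id():
    # globally unique: (worker id, locally increasing counter)
    return (_worker_id, next(_local_counter))

def get_or_create_context(context_id):
    # called on every worker that participates in this backward pass
    with _lock:
        return _contexts.setdefault(context_id, DistAutogradContext(context_id))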

Send and Recv Autograd Functions

Let’s start with a simple example where there is just one synchronous RPC call and only one tensor passed across worker boundaries. The code is shown below, and the corresponding autograd graph appears in the figure that follows it (AccumulateGrad autograd functions for leaf nodes are omitted for simplicity).

# the add function should be 
# defined on both workers
def add() -> Tensor:
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = a + b
    return c
    
# make RPC call from worker0
# to execute add on worker1
c1 = dist.rpc(add, on="worker1")
d = torch.ones_like(c1)
e = c1 * d
e.sum().backward()

[Figure: distributed autograd graph for the example above, with a send/recv pair connecting worker0 and worker1]

The send and recv autograd functions are inserted during the forward pass, connecting the two local graphs into one distributed graph. In the backward pass, the gradient is first passed to the recv autograd function on worker0, which then transmits the gradient tensor to worker1’s send autograd function. Worker1 can then kick off its local autograd engine to resume the backward pass. There are a few more details that need to be clarified in this simple example:

  • On worker1, how do we keep the autograd graph alive after the RPC call returns?
    • In short, the distributed autograd engine on worker1 will keep a reference to the send function which can keep the graph alive.
    • Reasoning: The graph can be kept alive by keeping a reference to either tensor C or the send autograd function, as both of them hold a reference to the add autograd function. We choose to keep a reference to the send function instead of tensor C, because C as a non-leaf node produced by add is not needed in the backward pass. It should be freed as soon as possible. It is not memory efficient to hold C alive just because we want to have an entrance point to the autograd graph.
  • In the backward pass, how does recv on worker0 find the correct send on worker1 to talk to?
    • This can be done by assigning a globally unique ID (worker_id + local send/recv id) for each send / recv function pair.
  • When can worker1 delete its local autograd graph?
    • send should have the same lifetime as its corresponding recv function. This can be done by sending a message from worker0 to worker1 when recv is destructed on worker0. The recv function is kept alive by the loss tensor, so, conceptually, the global autograd graph will be deleted when the final loss tensor is gone. (See the sketch after this list.)
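
The lifetime rule in the last bullet could be implemented roughly as below: the recv function, kept alive by the loss tensor, notifies the callee on destruction so that the matching send function (and the subgraph it keeps alive) can be released. The agent class and the control message are illustrative only.

class _StubAgent:
    def send_control(self, dest, msg):
        # a real implementation would send an RPC control message to dest
        print("to", dest, ":", msg)

class RecvFunc:
    def __init__(self, pair_id, callee, agent):
        self.pair_id = pair_id   # globally unique (worker_id, local id) of the pair
        self.callee = callee
        self.agent = agent

    def __del__(self):
        # once the loss tensor (and hence this recv) is collected, tell the
        # callee to drop the corresponding send function and free its subgraph
        self.agent.send_control(self.callee, ("RELEASE_SEND", self.pair_id))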

Hidden Autograd Path and Circular Dependency

Things can become complicated when an autograd graph contains multiple send/recv pairs. Consider the following example.

# all functions should be defined on all workers
def worker0_func(c2: Tensor) -> Tensor:
    g = torch.rand(2, 2)
    h = g + c2
    return h

def worker1_func_top() -> Tensor:
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = a + b
    return c

def worker1_func_bottom(c: Tensor, e1: Tensor) -> Tensor:
    f = c + e1
    return f

def worker2_func(c1: Tensor) -> Tensor:
    d = torch.rand(2, 2)
    e = c1 + d
    return e

# on Worker3
c_ref = torch.remote(worker1_func_top, on="Worker1")
h1 = torch.rpc(worker0_func, c_ref, on="Worker0")
e_ref = torch.remote(worker2_func, c_ref, on="Worker2")
f1 = torch.rpc(worker1_func_bottom, c_ref, e_ref, on="Worker1")
i = h1 + f1
i.sum().backward()

[Figure: distributed autograd graph for the example above, spanning Worker0–Worker3 with multiple send/recv pairs]

This example highlights two problems that we need to address:

  • Hidden Autograd Path: The existing local autograd engine starts from the loss (or all outputs) and does a discovery/marking phase to identify all participating functions before executing the real autograd computation, so that all paths in the autograd graph are known upfront. We don’t have this luxury in distributed autograd, because some parts of the autograd graph reside on remote workers. For example, when a grad arrives at send5, worker1 cannot tell whether send3 will be in the backward pass if it only looks at local information. More specifically, i.sum().backward() looks the same as f1.sum().backward() from worker1’s perspective, but the former involves send3 and the latter does not.
    • To address this problem, we propose to record all globally upstream (upstream in the forward pass, downstream in the autograd graph) send / recv pairs in the forward pass, so that we know exactly which send / recv to wait for in the backward pass.
  • Circular Dependency: there are circular dependencies between worker1 and worker2, i.e., it is impossible to finish autograd computation on one worker before kicking off on another worker. One option is to start autograd computation on worker1 first, and having an autograd thread blocking there waiting for grads for send1, but this is less ideal.
    • To address this problem, we propose to only create the send autograd function and put it in the ready queue when the grad is received. Note that, when computing dependency count for add1, the autograd engine still takes send1 into account, so that the engine will only start computing grads for add1 after both add2 and send1 finish.

Note that we need to record information in the forward pass and do the discovery in the backward pass because we don’t know which send functions will participate in the autograd computation. However, if the application can guarantee that all send functions will receive grads in the backward pass, we can skip all of this complexity and have a more efficient version. Both scenarios are useful, so we propose to have two modes:

  • Smart Mode supports running backward on a subgraph of the global autograd graph, but there will be extra overhead in both forward and backward pass.
  • Fast Mode skips dependency recording in the forward pass and graph discovery in the backward pass, but the application needs to guarantee that all send autograd functions will receive grads in the backward pass.

The two sections below describe the two algorithms in more details.

Distributed Autograd Algorithm Smart mode

Forward pass:

For every send x:

  1. Find send functions in x’s lineage, by:
    1. Find all recv functions that are locally reachable from send x in the autograd graph. In the example above, send2 finds recv1, send4 finds recv3, and send5 finds recv2.
    2. Use those found recv functions to find globally reachable recv functions in send x’s lineage. Note that this can be done, because in step 2 we send enough information from send to recv. In the example above send4 knows send3, and send5 knows send1 and send2.
  2. Then, send x includes the ids of its lineage send functions in the message. Intuitively, it means that if a grad is received for send x, the backward pass must reach all send functions in its lineage as well. This helps a node determine whether it should wait for a send grad.
# pseudo code to demonstrate how send works in forward
def find_global_lineage(tensor):
    # find local lineage
    recvs = find_recvs(tensor.grad_fn)
    dep_ids = {recv.id for recv in recvs}
    # find global lineage
    dep_ids.update({dep_id for recv in recvs for dep_id in recv.dep_ids})
    return dep_ids

def send(func, tensors, on):
    msg = Message(func)
    for tensor in tensors:
        lineage = find_global_lineage(tensor)
        # connect a send function to the autograd graph
        send_fn = SendFunc()
        send_fn.next = tensor.grad_fn
        # remember the send function by its id
        RpcAgent.send_map[send_fn.id] = send_fn
        # coalesce data
        msg.data.append((tensor, send_fn.id, lineage))
    send_msg(msg, on)
    
def recv(func, data, from_worker):
    tensors = []
    for tensor, send_id, lineage in data:
        # use send_id as recv_id, and remember the global lineage
        recv_fn = RecvFunc(send_id, lineage)
        tensor.grad_fn = recv_fn
        tensors.append(tensor)
        
    return func(*tensors)

Backward pass:

On the node that calls torch.distributed.backward:

  1. Find all send functions in the lineage of the loss tensor. In the above example, it will be all 5 send functions. These ids will be propagated to the recv functions and will be passed to the counterpart send functions accordingly.
    1. Optimizations can be added, e.g., drop unnecessary ids in backward pass to reduce message size.

On every node:

  1. Upon receiving the first message (be it a dedicated discovery message or the grad of a send), record its autograd_context_id, and retrieve all participating send ids from the message. Compute dependency counts from those send functions (and also from the loss grad_fn if the loss is on this node). Set the dependency count for send functions to 1. If any autograd function has a dependency count of 0, put it into the ready queue.
  2. Upon receiving a send grad, decrement the dependency count of that send by 1, and add it to the ready queue. Note this is done on an RpcAgent thread, and some autograd engine thread will pick up the autograd function for execution.
# pseudo code to demonstrate backward
graph_tasks = {}
def backward(loss):
    global graph_tasks
    
    autograd_context_id = gen_autograd_id()
    lineage = find_global_lineage(loss)
    # these send will participate in the autograd pass
    roots = list(local_sends.intersection(lineage))
        
    # propagate the autograd_id and deps info to all
    # participating workers. This is non-blocking and can
    # run concurrently with the real backward computation. 
    # This step is not absolutely necessary, but can help other
    # workers to kick off autograd earlier.
    disseminate(autograd_context_id, lineage)

    # below is a handwaving impl to show how it works with local autograd engine
    graph_task = GraphTask()
    graph_tasks[autograd_context_id] = graph_task
    roots.append(loss.grad_fn)
    # setup dependency count properly
    compute_dependencies(GraphRoot(roots), graph_task)
    # insert the task to local engine ready queue. Only the FunctionTask
    # for loss is inserted now, send FunctionTasks will be inserted later
    # when their grad becomes available.
    ready_queue.push_back(FunctionTask(graph_task, loss.grad_fn, ...))
    return autograd_context_id
    
    
def on_grad_send(send_id, grad, autograd_id):
    global graph_tasks
    graph_task = graph_tasks[autograd_id]
    send_func = RpcAgent.send_map[send_id]
    ready_queue.push_back(FunctionTask(graph_task, send_func, grad))

Distributed Autograd Algorithm Fast mode

The problem with the above approach is that including ids in send / recv messages incurs overhead, especially when there are a lot of tensors communicated across multiple workers. Moreover, the discovery phase is only necessary when running autograd on a subgraph. For example, f1.sum().backward() requires the discovery phase to avoid waiting for send3, but it is easier for i.sum().backward(), as all send functions are involved in the backward pass. So, we propose an additional mode for distributed autograd that bypasses send / recv dependency discovery in both the forward and backward passes if all sends for non-leaf or requires_grad tensors will receive grads in the backward pass. The mode can be toggled when initializing RPC agents:

# all_requires_grad (bool): If True, the application guarantees that all 
# send functions on non-leaf or requires_grad tensors will receive grad 
# in the backward pass. Hence, we can skip the distributed dependency 
# discovery algorithm (fast mode). If False, run smart mode, where 
# messages between send/recv will contain dependency ids in both the forward
# and backward passes. (default False)
torch.distributed.init_rpc(name, backend="pg", all_requires_grad=False)

Internally, RpcAgent will create a thread-local driver ID, where a driver is the worker that pieces together the autograd graph. In the above example, Worker3 is the driver. In the forward pass, every send function originating from this driver will be tagged with its thread-local driver ID, and this applies to all downstream (upstream in the autograd graph) send functions as well. This can be done either by propagating the driver ID to RPC calls recursively, or by doing an active driver ID discovery by walking the autograd graph before sending a tensor. If this information is ambiguous, e.g., one send function traces back to two upstream (downstream in the autograd graph) recv functions from two different drivers, it will throw an error. In the backward pass, the thread-local driver id of the loss will be included in the entire autograd execution to identify participating send functions. Note that, in this mode, the application cannot keep two disjoint autograd graphs alive at the same time, as that would break the assumption that all sends (originating from the driver) will receive grads in the backward pass.
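
A sketch of the thread-local driver id bookkeeping described above; the helper names are illustrative and the error handling is simplified:

# fast mode: tag every send originating from a driver with that driver's id
import threading

_tls = threading.local()

def set_driver_id(driver_id):
    _tls.driver_id = driver_id

def current_driver_id():
    return getattr(_tls, "driver_id", None)

def tag_send(send_fn, upstream_driver_ids):
    # upstream_driver_ids: driver ids found on recv functions upstream in the
    # forward pass (downstream in the autograd graph)
    ids = set(upstream_driver_ids)
    if current_driver_id() is not None:
        ids.add(current_driver_id())
    if len(ids) > 1:
        # one send traces back to recv functions from two different drivers
        raise RuntimeError("ambiguous driver id for send function: %s" % ids)
    send_fn.driver_id = ids.pop() if ids else None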

Concurrent distributed Backward passes

A = torch.rand(2, 2)
B = torch.rand(2, 2)
    
# on all workers
def add() -> Tensor:
    global A, B
    return A + B
    
# on worker0
C = torch.remote(add, on="worker2").to_here()
C.sum().backward()

# on worker1
C = torch.remote(add, on="worker2").to_here()
C.sum().backward()

In the above example, there are two concurrent backward passes triggered by worker0 and worker1 respectively, and both will reach worker2. To avoid races, the distributed autograd engine will use the globally unique autograd_context_id to create a dedicated context on every participating worker. Later, this autograd_context_id is passed to the optimizer to apply gradients. More concretely, this would work as follows:

  1. Compute all the leaf nodes in the autograd graph.
  2. As part of running distributed backwards, use the outputs parameter of the autograd engine to avoid executing AccumulateGrad for the leaf nodes we have and instead return the appropriate output_edges to execute for accumulating gradients.
  3. Store the output_edges with the autograd_context_id. This would ensure multiple backward passes won't accumulate gradients in the same context (a sketch of such context-scoped accumulation follows the pseudo-code below).
  4. This completes the backward pass and gradients are accumulated in the autograd engine per autograd_context_id.
  5. Now we run the optimizer on each of the worker nodes and pass the autograd_context_id to the optimizer.
  6. The optimizer applies all the gradients to the leaf nodes that we computed originally.
  7. The context and the gradients it holds should be destroyed when the autograd_context_id is destructed on the caller of backward().

Some pseudo-code to illustrate this:

optimizer = dist.optimizer(model)
loss = model(inputs)
bw_ctx_id = dist.autograd.backward(loss, timeout=60) # timeout of 60s
optimizer.step(bw_ctx_id)
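
A sketch of the context-scoped gradient accumulation referred to in step 3, so that two concurrent backward passes on the same worker never race on param.grad; the function names are illustrative:

import torch

_context_grads = {}   # autograd_context_id -> {parameter: accumulated grad}

def accumulate_grad(autograd_context_id, param, grad):
    bucket = _context_grads.setdefault(autograd_context_id, {})
    if param in bucket:
        bucket[param] += grad
    else:
        bucket[param] = grad.clone()

def optimizer_step(autograd_context_id, params, lr=0.1):
    # the optimizer reads gradients from the given context instead of param.grad
    bucket = _context_grads.get(autograd_context_id, {})
    with torch.no_grad():
        for p in params:
            if p in bucket:
                p -= lr * bucket[p]

def release_context(autograd_context_id):
    # called when the autograd_context_id is destructed on the caller of backward()
    _context_grads.pop(autograd_context_id, None)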

RRef

(more details are described in #26759)

RRef is an important concept for building a distributed autograd graph. Each RRef is owned by a single worker (i.e., owner) and can be used by multiple users. The owner stores the real data referenced by its RRefs, and keeps track of the global reference counts for its RRefs. Every RRef can be uniquely identified by a global id ref_id, which is assigned at the time it is first created either on a user or on the owner.

The owner only keeps one RRef instance for each data object, while users can fork as many RRef instances as necessary. All usage on the owner should retrieve the RRef instance using the globally unique ref_id. A fork of an RRef is created when it is used as an argument or return value in an RPC call, but users don't need to worry about forking/forwarding and reference counting (RC) of RRefs. These are handled transparently, and every fork also has its own fork_id, which is guaranteed to be unique across all RRef instances for the same data object.
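
The identifiers above could be sketched as follows; the field and class names are illustrative, not the actual design:

import itertools

class GloballyUniqueId:
    # (creating worker id, locally increasing counter) -- unique across the cluster
    _counter = itertools.count()

    def __init__(self, worker_id):
        self.worker_id = worker_id
        self.local_id = next(GloballyUniqueId._counter)

    def __repr__(self):
        return "id(%d, %d)" % (self.worker_id, self.local_id)

class RRefFork:
    def __init__(self, owner, ref_id, creator_worker_id):
        self.owner = owner      # worker that stores the real data
        self.ref_id = ref_id    # identifies the data object; same for all forks
        # unique per fork, across all forks of the same data object
        self.fork_id = GloballyUniqueId(creator_worker_id)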

RRef needs to support fast and scalable RPC. Hence, in the RC design, we avoid using any global master to keep RRef states. Besides, when worker X invokes RPC on worker Y, Y should be able to start immediately after receiving the RPC request, without waiting for any third-party owner Z (unless Y needs to pull real data from Z), even if neither X nor Y owns the RRef. We propose the following algorithm:

  1. If the owner is the RPC caller, the owner will update RC for the RRef accordingly.
  2. If the owner is the RPC callee, the owner will drop the new fork, and use the unique RRef id in the fork to access its singleton local RRef instance.
  3. If the RPC is between two users:
    1. The caller sends an RPC message to the callee, and also notifies the owner on the new fork.
    2. The owner, upon receiving the notification, updates its local RC and then tells the callee the new fork is now known by the owner.
    3. The callee can start executing the RPC as soon as it receives the RPC message from the caller, and does not need to wait for the message from the owner. However, it cannot delete its local RRef fork until the owner's message arrives (see the sketch after this list).
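
The user-to-user case (case 3) could be sketched as follows; the message names, the agent.send helper, and fork.new_child are all illustrative. The callee may start executing immediately, but keeps the fork alive until the owner's acknowledgement arrives.

MY_WORKER_ID = 3   # assumption: assigned at init_rpc time

def caller_send_rpc(agent, callee, owner, func, fork):
    # mint a child fork of the caller's fork to hand to the callee
    child_fork = fork.new_child(MY_WORKER_ID)
    agent.send(callee, ("RPC", func, child_fork))
    # notify the owner, including the parent fork id (needed for guarantee G2 below)
    agent.send(owner, ("FORK_NOTIFY", fork.ref_id, fork.fork_id, child_fork.fork_id, callee))

def owner_on_fork_notify(agent, rc_table, ref_id, parent_fork_id, child_fork_id, callee):
    rc_table.setdefault(ref_id, set()).add(child_fork_id)   # update reference count
    agent.send(callee, ("FORK_ACK", child_fork_id))

def callee_on_rpc(pending_acks, func, fork):
    pending_acks.add(fork.fork_id)   # the fork must survive until FORK_ACK arrives
    return func(fork)                # no need to wait for the owner to start running

def callee_on_fork_ack(pending_acks, fork_id):
    pending_acks.discard(fork_id)    # the fork may now be deleted once locally unused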

Reference Count

The right time to delete an RRef on the owner is when there are no living forks on any user and Python GC also agrees to delete the RRef instance on the owner. The tricky part is to determine whether there are any living forks.

A user can get a fork in three situations:

  1. Receiving a fork from the owner.
  2. Receiving a fork from another user.
  3. Creating a new RRef fork owned by another worker.

#1 is the simplest case where the owner initiates the fork, and hence it can easily increase local RC. The only requirement is that any fork must notify the owner before destruction. Hence, we need the first guarantee:

  • G1. The owner will be notified when any fork is deleted.

Note that the notification might come delayed or out-of-order.

With #2 and #3, it is possible that the owner only partially knows the RRef fork graph, or does not know it at all. For example, the RRef could be constructed on a user, and before the owner receives the RPC call, the creator user might have already shared the RRef with other users, and those users could further share the RRef. One invariant is that the fork graph of any RRef is a tree rooted at the owner, because forking an RRef always creates a new RRef instance, and hence every RRef has a parent. One nasty detail is that when an RRef is created on a user, technically the owner is not its parent, but we still consider it that way, and it does not break the argument below.

The owner's view of any node (fork) in the tree has three stages: 1) unknown → 2) known → 3) deleted, and the owner's view of the entire tree keeps changing. The owner deletes its RRef instance when it thinks there are no living forks, i.e., when all forks are either indeed deleted or unknown. Therefore, the dangerous case is when some forks are unknown and others are deleted. We only need a simple guarantee to prevent this situation:

  • G2. No fork x can be deleted on a user before the owner knows x’s parent fork.

This works because the owner's view of x can only change from known to deleted when x's parent is known or deleted. If the parent is known, the owner will not delete the local RRef. If the parent is deleted, this rule recursively applies to the parent's parent, until it reaches the root (the owner). To implement the guarantee, we only need to make the caller include its own fork_id when notifying the owner of a new fork.

G1 and G2 guarantee correct RC, but they do not prevent a user from deleting a fork before its own prior RPC calls using that fork have finished. This should be OK, because the deserialized RPC message holds a reference to that RRef, preventing it from being deleted prematurely.
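
A sketch of the owner-side bookkeeping that G1 and G2 enable; the class and method names are illustrative:

class OwnerRRefState:
    def __init__(self, value):
        self.value = value
        self.known_forks = set()     # fork ids the owner has been told about
        self.deleted_forks = set()   # fork ids reported deleted (guarantee G1)

    def on_fork_notify(self, parent_fork_id, child_fork_id):
        # G2: the parent fork id arrives with the child, so the owner never
        # processes a deletion for a fork whose ancestry it does not know
        self.known_forks.add(child_fork_id)

    def on_fork_deleted(self, fork_id):
        self.deleted_forks.add(fork_id)
        self.maybe_free()

    def maybe_free(self):
        # delete the local value only when every known fork has been deleted
        if self.known_forks <= self.deleted_forks:
            self.value = None   # let Python GC reclaim the data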

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera

mrshenli added a commit that referenced this issue Jul 24, 2019
sync and async torch.distributed.rpc for builtin operators

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#23228 sync and async torch.distributed.rpc for builtin operators**

Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

* have a minimum working and testable RPC implementation for #23110
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementations
  * For the TensorPipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only converts a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has its own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueueing the `(from, request, RpcAgent&)` tuple and using a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Jul 25, 2019
…rators"

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#23228 sync and async torch.distributed.rpc for builtin operators**

Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

* have a minimum working and testable RPC implementation for #23110
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Jul 25, 2019
…rators"

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#23228 sync and async torch.distributed.rpc for builtin operators**

Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

* have a minimum working and testable RPC implementation for #23110
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Jul 25, 2019
…rators"


Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

* have a minimum working and testable RPC implementation for #23110
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Jul 25, 2019
…rators"


Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

* have a minimum working and testable RPC implementation for #23110
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Jul 26, 2019
…rators"


Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

* have a minimum working and testable RPC implementation for #23110
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Jul 26, 2019
Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

* have a minimum working and testable RPC implementation for #23110
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Aug 2, 2019
Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

This is the first PR for #23110, and there will be many followup ones. So let's focus on the overall API and code structure. Details like efficiency and error handling can be improved in future PRs.

* have a minimum working and testable RPC implementation.
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Aug 5, 2019
Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

This is the first PR for #23110, and there will be many followup ones. So let's focus on the overall API and code structure. Details like efficiency and error handling can be improved in future PRs.

* have a minimum working and testable RPC implementation.
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Aug 5, 2019
Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

This is the first PR for #23110, and there will be many followup ones. So let's focus on the overall API and code structure. Details like efficiency and error handling can be improved in future PRs.

* have a minimum working and testable RPC implementation.
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
mrshenli added a commit that referenced this issue Aug 5, 2019
Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

This is the first PR for #23110, and there will be many followup ones. So let's focus on the overall API and code structure. Details like efficiency and error handling can be improved in future PRs.

* have a minimum working and testable RPC implementation.
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation
  * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.



Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
pritamdamania87 pushed a commit that referenced this issue Oct 8, 2019
Pull Request resolved: #27022

This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of computing dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.
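
For readers following along, here is a minimal usage sketch of the flow described above, written against the stable `torch.distributed.autograd` / `torch.distributed.rpc` module names this work eventually shipped under (the exact entry points in this particular PR may differ slightly, and the worker name is illustrative):

```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# assumes rpc.init_rpc(...) has already been called on every worker,
# and that a worker named "worker1" exists (name is illustrative)
with dist_autograd.context() as context_id:
    t1 = torch.rand(3, 4, requires_grad=True)
    t2 = torch.rand(3, 4, requires_grad=True)

    # forward pass crosses an RPC boundary; send/recv autograd functions
    # are recorded in the current distributed autograd context
    loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()

    # distributed backward pass (steps 1-8 above)
    dist_autograd.backward(context_id, [loss])

    # gradients are accumulated per-context rather than in .grad
    grads = dist_autograd.get_gradients(context_id)
```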

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose an `execute_with_graph_task` API which allows the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose an `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait() receives a message of type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is in line with what most
Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.
ghstack-source-id: 91558984

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)
rohithkrn added a commit to ROCm/pytorch that referenced this issue Oct 9, 2019
* Implement C++ API version of torch.nn.functional.one_hot (#27081) (#27177)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27177

Add support for F::one_hot C++ function.
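
For reference, the Python counterpart being mirrored behaves as below (illustrative snippet, not part of this diff):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])
F.one_hot(labels, num_classes=3)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])
```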

Test Plan:
Added 3 new tests to verify API is working

Imported from OSS

Differential Revision: D17697934

fbshipit-source-id: a8127fb87c00daa119bb92a5702bc4bbba48290d

* Refactor torch::jit::script::Module::register_* API. (#27189)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27189

Conceptually, Module is just a view over ClassType and ivalue::object.
register_ methods are the only methods that are an exception to this:
they provide an API not available on ClassType or object directly. This
PR ports this API to ClassType and makes Module truly just a view over
those two.

Test Plan: Imported from OSS

Differential Revision: D17703533

Pulled By: ZolotukhinM

fbshipit-source-id: 2cdb9fb486b3fb8527986483c7f34be7bd59fabf

* Add c10_experimental ops to BC check white list (#27235)

Summary:
Experimental ops don't provide a BC guarantee.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27235

Reviewed By: hl475

Differential Revision: D17723292

Pulled By: houseroad

fbshipit-source-id: 644ae34d130418a810e0f9d802fa25f6e34c5ccf

* Rename _intrinsic to intrinsic

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27194

Test Plan: Imported from OSS

Differential Revision: D17704957

Pulled By: zafartahirov

fbshipit-source-id: 46f02d129aa77c3047b2a6c606bfadd831a6b0fc

* Allow set for qconfig for dynamic_quantize

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27181

Test Plan: Imported from OSS

Differential Revision: D17717482

Pulled By: jamesr66a

fbshipit-source-id: f3930fc87831cbdcf4390cd769c594bb13f5cd81

* Fix reprs for _intrinsic modules

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27184

Test Plan: Imported from OSS

Differential Revision: D17717481

Pulled By: jamesr66a

fbshipit-source-id: 4bd72bcd42191d9b21d03f5bb6698198dbffffda

* skip all rpc and dist autograd spawn tests for <PY36 (#27191)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27191

skip rpc and dist autograd spawn tests for Python < 3.6
ghstack-source-id: 91231565

close #27157

Test Plan: unit tests

Differential Revision: D17697368

fbshipit-source-id: bb8cf1f47de41f9d350fd60afe37fece293d8680

* Add send and recv backward functions for builtin operators RPC. (#25527)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25527

Master GH issue: https://github.com/pytorch/pytorch/issues/23110.

This change builds upon https://github.com/pytorch/pytorch/pull/24876 and
provides all the autograd hooks needed for a forward pass with distributed rpc
for builtin operators. This change does not address distributed rpc for python
UDFs and that will be addressed in follow up PRs.

Summary of changes:
1. Attach send autograd functions when a request is sent from the client and
response is sent from the server.
2. Attach receive autograd functions when a request is received on the server
and a response is received on the client.
3. Generate a globally unique autograd_message_id for each send/recv autograd
function pair to uniquely identify them.
ghstack-source-id: 91240466

Test Plan: unit tests.

Differential Revision: D17148077

fbshipit-source-id: 192d8a3f552ed7cc939f55dcca332965c9bd3233

* Rename jit Function to ScriptFunction

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27219

Test Plan: Imported from OSS

Differential Revision: D17715306

Pulled By: albanD

fbshipit-source-id: d11a7634dbee6a885c7177b240958e5aed2544f3

* Make cpp-backed jit classes appear as being in torch.jit

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27220

Test Plan: Imported from OSS

Differential Revision: D17715305

Pulled By: albanD

fbshipit-source-id: 574704ad23ece6da7aa2780b78867307bef523cc

* Avoid configuring ROCm if USE_CUDA is on. (#26910)

Summary:
Move the resolution of conflict between `USE_CUDA` and `USE_ROCM` to CMake as to effectuate:

- `USE_CUDA=ON` and CUDA is found, `USE_ROCM=ON` and ROCM is found --> fatal error
- Either `USE_CUDA=ON` and CUDA is found or `USE_ROCM=ON` and ROCM is found --> The respective GPU feature is ON
- Otherwise no GPU support
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26910

Differential Revision: D17738652

Pulled By: ezyang

fbshipit-source-id: 8e07cc7e922e0abda24a6518119c28952276064e

* Revert "Add std::variant backport as c10::variant (#26836)" (#27277)

Summary:
This reverts commit 0cd188035a27fc38ce1e8eee205f6d47cd7650e6.

As reported by jerryzh168 and pritamdamania87, mpark::variant doesn’t compile with gcc 7.3.1 on the fb devserver and throws an error similar to https://github.com/mpark/variant/issues/43. (However, it doesn’t fail with gcc 7.3.1 in OSS CI, based on https://circleci.com/api/v1.1/project/github/pytorch/pytorch/2995606/output/107/0?file=true)
A plausible workaround is to upgrade the devserver to devtoolset-8, but that would in turn cause the CUDA build to complain:
```
/usr/local/cuda/bin/../targets/x86_64-linux/include/crt/host_config.h:119:2: error: #error -- unsupported GNU version! gcc versions later than 7 are not supported!
 #error -- unsupported GNU version! gcc versions later than 7 are not supported!
```
(Thanks pritamdamania87 for the report!)

The solution for now is to revert the mpark::variant addition, and I will find alternatives that will work with gcc 7.3.1 on fb devserver.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27277

Differential Revision: D17739804

fbshipit-source-id: ad945b3d86ab7ddbff58f4ecab95e0e1ac725ae9

* Implement LpNorm regularizer to be used on the inputs for feature importance (#26376)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26376

* Create the new dense_feature_reg (FCInputLpNorm) to be applied to the fully-connected layer for feature importance.

Test Plan: * Unit test located in: `caffe2/caffe2/fb/dper/layer_models/tests/split_1/sparse_nn_test.py`

Reviewed By: un-disclosed

Differential Revision: D17360361

fbshipit-source-id: 1a0e119eeb17199a13dfffe58b3036ea4255e301

* Provide (but skip) 3.5 job by default on all PRs. (#27293)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27293

This doesn't turn on 3.5 signal, but it makes it so that [test all]
will include it if you do request it.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D17738741

Pulled By: ezyang

fbshipit-source-id: 2b1af4d7bf26fd84a593fde292d6bfa2aabc1148

* more profiler changes in C++ before enabling checkScript changes

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26909

Differential Revision: D17683632

Pulled By: Krovatkin

fbshipit-source-id: 5d36c3c4cf7411c56485ef19fe59262b9f8b45b2

* Fix segfault while printing value type for an error msg in emitListComprehension

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27261

Differential Revision: D17740159

Pulled By: Krovatkin

fbshipit-source-id: 90439282aea14d8634eb41ffece5b6320d615fa7

* Factored out the default mappings

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27164

Test Plan: Imported from OSS

Differential Revision: D17694475

Pulled By: zafartahirov

fbshipit-source-id: df8df5f7d66062ed35da957064a31344e1d3c961

* Add memory format argument to the `clone` operator (#27106)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27106

Adds memory_format option to the `clone` operator.

Introduces new `clone` behavior when used as `input_t.clone(memory_format=torch.preserve_format)`:
1) If the tensor is non-overlapping and dense, the output tensor will have the same strides as the input tensor.
2) If not (1) and the tensor is stored in the channels last format, the output tensor is going to have the channels last format.
3) The output tensor is going to be contiguous in all other cases.

 ---
A dense tensor is a tensor that stores values in a contiguous block of memory.
A non-overlapping tensor is a tensor in which each element occupies its own, non-repeated memory location.
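
A small usage sketch of rules (1) and (2) above (illustrative only, not taken from the PR):

```python
import torch

# a 4D tensor laid out in channels-last order
x = torch.randn(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)

# preserve_format keeps the channels-last strides of the source tensor
y = x.clone(memory_format=torch.preserve_format)
assert y.is_contiguous(memory_format=torch.channels_last)

# the default contiguous_format always produces a row-major tensor
z = x.clone(memory_format=torch.contiguous_format)
assert z.is_contiguous()
```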

Test Plan: Imported from OSS

Differential Revision: D17699357

Pulled By: VitalyFedyunin

fbshipit-source-id: 5ae1537c2aca1abf0bf1eec4416846129c156f66

* Extract version to version.txt (#27149)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27149

Extract version to version.txt and add reading version logic to setup.py and fb/torch_version.py
ghstack-source-id: 91271883

Test Plan: N/A

Reviewed By: gchanan, ezyang

Differential Revision: D17689307

fbshipit-source-id: 21899502027cec71b63d9dc151e09ff5ff3f279d

* add AutoNonVariableTypeMode for USE_STATIC_DISPATCH on JIT->ATen path (#27274)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27274

This is yet another fix to address #26764.

PR #26908 toggles NonVariableTypeMode in ATen dispatcher, which is where
USE_STATIC_DISPATCH takes place thus it's most logically sound place to do
such tweaks.

However, we observed a nontrivial perf regression due to this fix. It turns out
the numel() tensor method gets called in several for-loops, thus incurring ~7M
thread_local updates in a single forward call:
```
7173330 numel
    558 size
    416 q_scale
    302 _empty_affine_quantized
    288 contiguous
    257 q_zero_point
    216 qscheme
    173 empty
    110 set_
    105 as_strided
    104 permute
...
```

As numel() is not called from a single place, a natural workaround is to
update function_wrapper.py so that it only adds the guard in the gen_namespace_function()
case and ignores the gen_tensor_method() case. But some tensor methods are actually
being called from the JIT side directly (e.g. "aten::eq_" -> "(self).eq_"), so the
only "band aid" left on the table is to insert the guard on the JIT->ATen path as originally
done in #26868 - this is a simplified version of it, as it doesn't hurt to extend the
NonVariableMode scope a little bit to also cover stack drop/pack calls.

On Android we only expose the JIT API, so we don't need to worry about TensorMethods being
called directly. On iOS we don't provide a wrapper yet, but we can mention this caveat
in the doc. Hopefully, by the time it's widely used we can finish the Variable/Tensor
unification and remove all these hacks.

Test Plan:
- Verified it runs quantized/fp32 MobileNetV2 models;
- Verified it fixes the perf regression (revert #26908 separately);

Differential Revision: D17732489

Pulled By: ljk53

fbshipit-source-id: c14ca66aebc6b6f17ad6efac7ca47f9487c98de5

* Updating submodules

Summary:
GitHub commits:

https://github.com/pytorch/fbgemm/commit/8786c0819029c076b0e28320e880ba3ac192ea8b

Test Plan: n/a

Reviewed By: zpao

fbshipit-source-id: 9c04a2ba7cc2166db0203f186ece261ca8b186dd

* Avoid calling tensor.numel() in for loops (#27298)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27298

PR #26908 toggles NonVariableTypeMode in ATen dispatcher, which is where
USE_STATIC_DISPATCH takes place.
This causes an issue with numel() as it gets called through the dispatch mode and probably doesn't get inlined.
Also, the thread local state is expensive to read/write so many times, and this kills perf.

PR #27274 is another approach to fix this and has more details.

Test Plan:
Quantized mobilenetV2 perf before this change
Main run finished. Milliseconds per iter: 28.6782. Iters per second: 34.8696

Perf after this change
Main run finished. Milliseconds per iter: 22.2585. Iters per second: 44.9267

Imported from OSS

Differential Revision: D17742565

fbshipit-source-id: 43c6045cc001c46916ba339555c9d809a2537eff

* Fix circle CI

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27307

Test Plan: Imported from OSS

Differential Revision: D17746444

Pulled By: xta0

fbshipit-source-id: ed37f91921f1ea7db6c63ba69f04883856341c39

* Update the link for iOS demo app in README.md (#27145)

Summary:
Update the link for iOS demo app in README.md
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27145

Differential Revision: D17746591

Pulled By: xta0

fbshipit-source-id: 6f49a0daddc8b79804e1b8487ba1db3807a3f481

* Allow use cpu_serial_kernel with void-lambda (#27271)

Summary:
Currently we use CPU_tensor_apply1 to loop through the tensor in a single thread and aggregate data:
```
// compute variance per input
 accscalar_t var_sum = 0;
 CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
    var_sum += (i - mean) * (i - mean);
 });
```
and we don't have the ability to use TensorIterator for this.

```
accscalar_t var_sum = 0;
auto iter = TensorIterator::unary_op(self, self);
  cpu_serial_kernel(iter, [&](scalar_t i) -> scalar_t {
        var_sum += (i - mean) * (i - mean);
  return a; //Unable to set value back, because self should be const
});
```

This PR should resolve this problem and allow to use void-lambda:
```
auto iter = at::TensorIterator();
iter.add_input(in);
iter.build();
accscalar_t var_sum = 0;
at::native::cpu_serial_kernel(iter, [&](scalar_t i) -> void {
   var_sum += (i - mean) * (i - mean);
});
```

In the future it makes sense to change the Reduction part and allow reducing to a scalar, not just to a tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27271

Differential Revision: D17743310

Pulled By: ifedan

fbshipit-source-id: a149751f2d671aefd3ed84bd50b2c0543a63b701

* Move the CUDA implementation of log10 to ATen. (#26733)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26733

Close #24587

Test Plan: Imported from OSS

Differential Revision: D17606981

Pulled By: VitalyFedyunin

fbshipit-source-id: 732f07b981287da3ca235b272b7b6f78144f8ebe

* Mention magma-cuda101 package in install instructions (#27325)

Summary:
There is a magma package for the newest CUDA version (10.1); mention it here lest someone mistakenly use the version for CUDA 10.0.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27325

Differential Revision: D17749535

Pulled By: soumith

fbshipit-source-id: 2d34a7af1218e6157935bfd5e03f4d2c0f00f200

* C++ API parity: TensorTest.BackwardNonScalarOutputs

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27314

Test Plan: Imported from OSS

Differential Revision: D17746371

Pulled By: pbelevich

fbshipit-source-id: 246fae22a60ed9a6d7b9843239b4b3391cc9dc3e

* Fix build (#27318)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27318

Fix TBB build
USE_TBB=1 ATEN_THREADING=TBB python setup.py develop install --cmake

Test Plan: Imported from OSS

Differential Revision: D17747449

Pulled By: ilia-cher

fbshipit-source-id: 421f362bd10f3be34bffe86ae4f26e8f1c15f1a4

* Relax restrictions on set_num_threads (#27190)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27190

Allow set_num_threads to be called multiple times in case of TBB
parallel backend

Test Plan:
BUILD_BINARY=1 USE_TBB=1 ATEN_THREADING=TBB python setup.py develop
install  --cmake
./build/bin/test_parallel
./build/bin/thread_init_test

Reviewed By: kostmo

Differential Revision: D17704236

Pulled By: ilia-cher

fbshipit-source-id: 274380795e78ba417301c5faa18c9e9d3198bd5e

* Migrate the cpu and gpu implementations of resize nearest 3D from vision to caffe2

Summary: As title. Fix the build failures in unicorn-build-restrictions as discussed in D17330625

Test Plan:
buck test mode/opt caffe2/caffe2/quantization/server:resize_nearest_3d_dnnlowp_op_test

In vision libs, no need to explicitly add dep to resize 3d op as the caffe2_cpu dep is added by default.

Reviewed By: stephenyan1231

Differential Revision: D17676082

fbshipit-source-id: c034ab67a9078f72077b396991ffb9e54e6ab40b

* Add method add_hparams to API doc (#27344)

Summary:
Adds the method `add_hparams` to `torch.utils.tensorboard` API docs. Will want to have this in PyTorch 1.3 release.

cc sanekmelnikov lanpa natalialunova
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27344

Differential Revision: D17753689

Pulled By: orionr

fbshipit-source-id: cc8636e0bdcf3f434444cd29471c62105491039d

* Support interface python assignment as an attribute (#26734)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26734

This PR adds Python assignment of an interface as an attribute in a
module. It enables any object that implicitly inherits from the specific
interface to be assigned to the interface type in Python.

Serialization support for interface/class assignment will be done in the
follow up PR

Test Plan: Imported from OSS

Differential Revision: D17742708

Pulled By: wanchaol

fbshipit-source-id: a0a2d8c74b60ed3fa6c05e1b0d49b7ad1abc670b

* Skip tests that use numpy if it's not present

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27165

Pulled By: driazati

Differential Revision: D17695078

fbshipit-source-id: d25c920f4c43285028537f88761d47a2c9db7b8f

* Add Python RRef as args and return value (#25499)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25499

See #23110 for model parallel design details, and #26759 for the RRef
protocol. This commit adds support for using RRef as Python UDF arguments
and return value. RRefs can now be shared from owner to user, from user to
owner, or from user to user.
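
A minimal sketch of user-to-user RRef sharing in the spirit of this change, written against the stable `torch.distributed.rpc` API (the worker names and the helper function are illustrative):

```python
import torch
import torch.distributed.rpc as rpc

def add_rrefs(rref_a, rref_b):
    # runs on "worker2"; to_here() fetches each value from its owner
    return rref_a.to_here() + rref_b.to_here()

# on the caller, after rpc.init_rpc(...) has been run on every worker
rref_a = rpc.remote("worker1", torch.ones, args=(2, 2))   # owned by worker1
rref_b = rpc.remote("worker1", torch.zeros, args=(2, 2))

# the RRefs travel as lightweight forks; only the result is copied back
result = rpc.rpc_sync("worker2", add_rrefs, args=(rref_a, rref_b))
```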

Limitations:
1. No implicit type conversion yet. (#27099)
2. No failure handling and retry. (#26116)
3. UDF is not yet blocked until all RRefs are confirmed. (#27098)
4. Internal RRef control messages are not idempotent yet. (#26116)
5. Cannot delete RRefs correctly when there are circular dependencies. (#27096)

Main changes:

1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations.
2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages.
3. New message request handling code is added to `functions.cpp`, and message format is added in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`.
4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork related control messages will be sent during RRef pickling/unpickling procedure.
5.  Update `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs.
6. RRef context (reference count, etc.) are tracked in `rref_context.h` and `rref_context.cpp`.

Test Plan:
Imported from OSS

buck test mode/dev-nosan //caffe2/test:rpc_fork

Differential Revision: D17184146

Pulled By: mrshenli

fbshipit-source-id: a3a268efc087ac1ef489136ab957080382629265

* Set MINIZ_NO_TIME to avoid computing localtime on each pickle/unpickle (#27268)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27268

For small pickle/unpickle, we spend a disproportionate amount of time in
time functions - roughly 23% in __tzset() for the unpickle case.

We're not using the .m_time currently, though we can add this feature
back if it's ever needed.

An alternative would be to -DMINIZ_NO_TIME in compiler_flags, but we would
need to also consistently # define MINIZ_NO_TIME in any .cpp including this .h,
since this # define modifies the struct length in an unfortunate manner.

Test Plan:
buck test mode/dev-nosan caffe2/test/...
Run benchmark:
 buck-out/opt/gen/caffe2/torch/fb/distributed/thriftRpcBackend/test/ThriftRpcAgentBench

Differential Revision: D17724198

fbshipit-source-id: b44a0217b1d9f8ce6c0f24297f59045c7cadf4b1

* Add a test case to RpcTest, check src/dst (#27322)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27322

# Problem

Existing test cases are too symmetric, so they didn't detect this error: a request sent to the wrong worker.

Because of a wrong `worker_names` setup, worker0 sent a request to itself, while it should have sent it to worker1.

# Solution

Add a test case, letting the dst side check whether the request came from the expected src.
ghstack-source-id: 91299312

Reviewed By: satgera

Differential Revision: D17069062

fbshipit-source-id: ef7a532dd497bfc0f0ee8446fcd5d29656aaf175

* Update to ROCm 2.8 (#27337)

Summary:
New docker images built with tag 324.

Related jenkins changes:
https://github.com/pytorch/ossci-job-dsl/commit/83ec81335742e66b02af90b7c74021b8792fc63f
https://github.com/pytorch/ossci-job-dsl/commit/aa235a14c82db69d0544cd8fc1da03ef9a50096e

Triggered CI runs:
https://ci.pytorch.org/jenkins/job/caffe2-builds/job/py2-devtoolset7-rocmrpm-centos7.5-trigger-test/48682/
https://ci.pytorch.org/jenkins/job/pytorch-builds/job/py2-clang7-rocmdeb-ubuntu16.04-trigger/55638/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27337

Differential Revision: D17753827

Pulled By: bddppq

fbshipit-source-id: 2c3f77b0b7c680013c7cc6d7953fe0da4922fe48

* add sdk support for xcodebuild script

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27358

Test Plan: Imported from OSS

Differential Revision: D17757389

Pulled By: xta0

fbshipit-source-id: ed8e470b9c6329b96297ee7c65ba08759251baad

* export remainder (#24410)

Summary:
Added ONNX export support for torch.remainder and torch.fmod
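
An illustrative export call exercising the newly supported ops (the wrapper module and opset choice are assumptions, not taken from the PR):

```python
import io
import torch

class RemFmod(torch.nn.Module):
    def forward(self, x, y):
        # both ops now have ONNX symbolics
        return torch.remainder(x, y), torch.fmod(x, y)

buf = io.BytesIO()
torch.onnx.export(RemFmod(), (torch.randn(3), torch.randn(3)), buf, opset_version=10)
```
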
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24410

Reviewed By: hl475

Differential Revision: D17466791

Pulled By: houseroad

fbshipit-source-id: afe6519e5f370824e3b4a45b69036a7260fb72cf

* Replacing the skip_list with white_list in the qconfig propagation

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27183

Test Plan: Imported from OSS

Differential Revision: D17700548

Pulled By: zafartahirov

fbshipit-source-id: 18e6ffbda496b14ac1da1783f928ad539cdb1d16

* Show a warning that not all dir members of quantized work. (#27339)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27339

This PR just shows a warning message.
Eventually we will show a correct __dir__

Test Plan: Imported from OSS

Differential Revision: D17751333

Pulled By: zafartahirov

fbshipit-source-id: e9bc62fd8dd0147979291d0aac3f1afe5b8c7a9f

* improve error messages when a method or attribute is missing (#27110)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27110

Previously missing methods on some types like tensors would talk about
'builtins' which are only a thing inside of the compiler. Furthermore,
the error would only occur when the builtin was applied and it was discovered
that no builtin existed. This changes the error message so that it
discovers, at attribute lookup, that the method does not exist on our builtin types.

Test Plan: Imported from OSS

Differential Revision: D17677616

Pulled By: zdevito

fbshipit-source-id: 2f7cf6c6093a9c832569c44f4b1044a2e56fe205

* refactor extra sugared values (#26270)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26270

We've accumulated a lot of sugared values whose only purpose is
to be instance-checked against in emitApplyExpr. I need to add
another one to insert an unchecked_cast, and do not want to continue
the pattern. This creates an abstraction for this concept (SpecialFormValue),
and removes all the unneeded sugared values. There is no functionality
change here, just a bunch of code movement in compiler.cpp.

Test Plan: Imported from OSS

Differential Revision: D17412854

Pulled By: zdevito

fbshipit-source-id: 15877c91decaea5a00d1fe737ed2d0f0f8a79a28

* Minor readability fixes to C++ documentation (#27338)

Summary:
Changed `yieldings` to `yielding`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27338

Differential Revision: D17758406

Pulled By: yf225

fbshipit-source-id: 1633834a6ad80449c061ebc330ac24f3e42f5506

* Choose num_threads in parallel_for based on GRAIN_SIZE (#26963)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/24080, Continuation of https://github.com/pytorch/pytorch/issues/26886

What soumith said in https://github.com/pytorch/pytorch/pull/26886#issuecomment-535760635 seems plausible
> I wonder if it has to do with `#pragma omp parallel num_threads(num_threads)` which has unintended consequences, where even if `num_threads=1`, entering an omp block inside an omp block results in bad behavior.

I know for a fact that gcc's openmp doesn't start the thread pool when given `num_threads(1)` but it seems clang behaves differently.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26963

Differential Revision: D17626981

Pulled By: soumith

fbshipit-source-id: 484ffe6cc172382bb5ff49ce1fceda7eba20a512

* Enable Python3.6 PyTorch ROCm CI

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27353

Differential Revision: D17758495

Pulled By: bddppq

fbshipit-source-id: 95e329bc30f092e4093a33c408f1647b803d9983

* Fixes PackedSequence.to (and unifies PackedSequence conversions) (#27245)

Summary:
PackedSequence.to(device) incorrectly places one of three tensors on the device and leaves the other two tensors where they are. If these devices are distinct then further operations on PackedSequence will fail. This behavior is inconsistent with Tensor.to and PackedSequence's behavior when .cuda() is called.

Additionally, PackedSequence defines multiple other conversion functions that were independently and inconsistently implemented.

This PR unifies all implementations and makes the PackedSequence.to behavior more consistent with Tensor.to. It is not completely consistent per comments. test_device_mask in test_nn.py is updated to validate the new functionality.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27245

Differential Revision: D17757850

Pulled By: mruberry

fbshipit-source-id: 58f0bd40f1aa300fb0a91ee743483d645f977dc5

* Makes test_cuda.py's generated tensor op tests generic (#27210)

Summary:
- The tensor op tests generated in test_cuda.py are now generic and appear in test_torch.py
- Data previously held in auxiliary data structures and files, like test_cuda_ignores.txt, is inlined

Previously the tensor op tests used several auxiliary data structures, a file, and exception handling to filter the test suite. If a function wasn't implemented, for example, that exception would be caught. This let functions like trigamma, which isn't callable, appear to be tested. See https://github.com/pytorch/pytorch/issues/27230. Filtering from additional data stores is error prone, too. It requires developers understand what data stores are used and how they're used. The existing sources are also sometimes incorrect. The txt file claims that dist_ doesn't work on half tensors, for example, but the updated tests verify it does.

In addition to making these tests generic, this PR removes those auxiliary data structures and does not catch any exceptions. Exceptions are errors. (This also means that if something implemented breaks it will now report as an error. Previously the test suite would have reported a pass.) The test infrastructure was also simplified to not perform computations with CPU half tensors since they do not support many operations. This introduces a float<->half conversion quirk but eliminates awkward functions that would first convert cpu tensors to float, perform an operation, and convert them back.

With this change test_cuda.py is almost entirely CUDA-specific.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27210

Differential Revision: D17757907

Pulled By: mruberry

fbshipit-source-id: b3c191c379667b1a7d5361087bdf82f397f77f65

* Remove six dependency (#27282)

Summary:
https://github.com/pytorch/pytorch/pull/27136 added a dependency on `six`, which is not available by default and is not marked as a dependency on PyTorch binaries, causing torchvision CI to break, see https://circleci.com/gh/pytorch/vision/20778?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link for example.

This PR uses `torch._six` instead of `six` as a replacement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27282

Reviewed By: lerks

Differential Revision: D17737561

Pulled By: fmassa

fbshipit-source-id: 7dcd0cc2c8bab27b8f4535f664f60388818d3497

* Make `align_to` method-only. (#27304)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27304

The ellipsis version of `align_to` only works if it is called as a
method. To prevent any confusion, this PR disables `torch.align_to` (but
keeps `Tensor.align_to`).
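
For illustration, the method form that remains supported (including the ellipsis overload):

```python
import torch

t = torch.randn(2, 3, 5, names=('C', 'N', 'H'))
t.align_to('N', 'C', 'H')   # method call with all names spelled out
t.align_to('N', ...)        # ellipsis version: only available as a method
```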

Test Plan: - [namedtensor ci]

Differential Revision: D17743809

Pulled By: zou3519

fbshipit-source-id: cf5c53dcf45ba244f61bb1e00e4853de5db6c241

* Remove CUDA_VERSION from Python script (which has already been detected in CMake) (#27316)

Summary:
(Intentionally left blank)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27316

Differential Revision: D17762715

Pulled By: ezyang

fbshipit-source-id: 044c0ea6e8c2d12912c946a9a50b934b5253d8c8

* Revert D17743310: [pytorch][PR] Allow use cpu_serial_kernel with void-lambda

Test Plan: revert-hammer

Differential Revision:
D17743310

Original commit changeset: a149751f2d67

fbshipit-source-id: 043240201d67966dd08b7b1bc2f9bf4897923e00

* Implement pickle support for sparse tensors and torch.layout instances (#27062)

Summary:
Resolves issue https://github.com/pytorch/pytorch/issues/16667 and https://github.com/OpenMined/PySyft/issues/2326
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27062
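
A quick illustration of what this enables (snippet is illustrative, not from the PR):

```python
import pickle
import torch

indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([3.0, 4.0])
sp = torch.sparse_coo_tensor(indices, values, (2, 2))

restored = pickle.loads(pickle.dumps(sp))   # round-trips without error now
assert restored.is_sparse and restored.layout is torch.sparse_coo
```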

Differential Revision: D17762932

Pulled By: ezyang

fbshipit-source-id: dd99c1f4ac8eb2286eb55aa20ce973f60ce7b7e1

* move new_zeros to core from THP (#26511)

Summary:
Fix for issue https://github.com/pytorch/pytorch/issues/25831

ezyang can you please have a look?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26511

Differential Revision: D17763037

Pulled By: ezyang

fbshipit-source-id: 3596c01c4ab421e7785d6055cc813806f840a5c7

* autograd: double backwards function for binary_cross_entropy loss

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26983
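
A small illustration of the second-order gradient this enables (not part of the diff):

```python
import torch
import torch.nn.functional as F

x = torch.rand(4, requires_grad=True)
target = torch.rand(4)
loss = F.binary_cross_entropy(torch.sigmoid(x), target)

grad, = torch.autograd.grad(loss, x, create_graph=True)
grad.sum().backward()   # double backward through binary_cross_entropy
print(x.grad)
```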

Reviewed By: albanD

Differential Revision: D17714357

Pulled By: anjali411

fbshipit-source-id: cebfe09a9048c4be457b7f2718bc396c06ecabee

* Change schedulers to chainable form (#26423)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26423

Enable chainable schedulers as requested in #13022 by implementing the changes mentioned below from [comment](https://github.com/pytorch/pytorch/pull/21800#issuecomment-513370208).

* Changing the behavior of schedulers to the chainable formula when available
* Using the closed form whenever epoch is different from None until the next release with a deprecation warning
* Making `get_computed_values` the supported way of obtaining the last computed learning rate by the scheduler (see [comment](https://github.com/pytorch/pytorch/pull/21800#issuecomment-513940729) for new syntax)
* Returning a deprecation warning when invoking the undocumented get_lr function (see [comment](https://github.com/pytorch/pytorch/pull/21800#discussion_r294305485)) referring to `get_computed_values`, and deprecating it in the next release.
* `CosineAnnealingWarmRestarts` still takes an epoch parameter, as it is the only one with a mechanism relying on a fractional epoch
* `MultiplicativeLR` consumes a function providing the multiplicative factor at each epoch. It mimics `LambdaLR` in its syntax.

# #20527

### Before

The user calls scheduler with a constant epoch either across loops or in the same loop.
```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3,3,3)
optimizer = optim.Adam(conv.parameters())
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

# Scheduler with sometimes-constant epoch number
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:
  lr_scheduler.step(epoch)
  print(optimizer.param_groups[0]['lr'])
```

### After

If the user wants to step
```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3,3,3)
optimizer = optim.Adam(conv.parameters())
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

last_epoch = -1
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:

  # Check if epoch number has changed manually
  if epoch-last_epoch > 0:
    lr_scheduler.step()
  last_epoch = epoch

  print(epoch, lr_scheduler.get_computed_values())
```

# #22107

### Before

```
import torch
from torchvision.models import resnet18
net = resnet18()

optimizer = torch.optim.SGD(net.parameters(), 0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
  # Scheduler computes and returns new learning rate, leading to unexpected behavior
  print(i, scheduler.get_lr())
  scheduler.step()
```

### After

```
import torch
from torchvision.models import resnet18

net = resnet18()
optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
    # Returns last computed learning rate by scheduler
    print(i, lr_scheduler.get_computed_values())
    lr_scheduler.step()
```

# ghstack

This contains the changes from #24352. Opening again since they were reverted.

This reverts commit 1c477b7e1f378e9c1f8efed296241f68a8a4372b.

Test Plan: Imported from OSS

Differential Revision: D17460427

Pulled By: vincentqb

fbshipit-source-id: 8c10f4e7246d6756ac91df734e8bed65bdef63c9

* Make RpcTest re-usable by other RPC backends by using init_method to initialize a RPC backend (#27320)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27320

https://github.com/pytorch/pytorch/pull/27208/

# Problem

Other RPC backends take init_method.

# Solution

Set up init_method in rpc tests.
ghstack-source-id: 91335127
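
A sketch of how a backend-agnostic test might initialize RPC with an explicit init_method, written against the stable API (the rendezvous address and backend options are assumptions):

```python
import torch.distributed.rpc as rpc

opts = rpc.TensorPipeRpcBackendOptions(init_method="tcp://localhost:29500")
rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=opts)
# ... test body issues rpc_sync / remote calls here ...
rpc.shutdown()
```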

Differential Revision: D17709219

fbshipit-source-id: 3184c6e9b922a6ff9f4d1cb9abfa118b23f43eeb

* Add OPN instruction and vararg operator table (#27104)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27104

* The use case here is to replace prim::ListConstruct, which requires Node, but Node is not available in mobile lite interpreter.
* (OPN, X, N): X is the index into the vararg operator-name and operator tables, and N is the number of inputs. For the ListConstruct example, the operator name can be "aten::listconstruct" and the overloaded name is the output type ("int", "float", "bool", "tensor" and "generic").
* A vararg operator table is built with void(int input_size, Stack& stack) functions.
## Unit test
LiteInterpreterConv covers OPN instruction and conv operator.

Test Plan: Imported from OSS

Differential Revision: D17762853

fbshipit-source-id: 475aa0c6678e3760cec805862a78510913a89c83

* Allow use cpu_serial_kernel with void-lambda (#27370)

Summary:
https://github.com/pytorch/pytorch/pull/27271
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27370

Differential Revision: D17763265

Pulled By: ifedan

fbshipit-source-id: d670560dfc555db529b18c01aa42f0ccb2127889

* From docs of scatter_add_() removed erroneous comment on uniqueness of indices. (#27132)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/27080
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27132
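
For reference, duplicate indices are legal and simply accumulate (illustrative snippet):

```python
import torch

x = torch.zeros(3)
index = torch.tensor([0, 0, 2])      # index 0 repeats; contributions accumulate
src = torch.tensor([1.0, 2.0, 3.0])
x.scatter_add_(0, index, src)        # -> tensor([3., 0., 3.])
```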

Differential Revision: D17765307

Pulled By: soumith

fbshipit-source-id: b0892ff442f3b49f8e3cdf029e2a08b51fa88f28

* Reduce error context from 10 -> 3 (#26765)

Summary:
10 lines of error context (on both sides) is overkill, especially now
that we have line numbers. With a compilation stack of a couple
functions, it becomes a pain to scroll to the top of the stack to see
the real error every time.

This also fixes class names in the compilation stack to a format of
`ClassName.method_name` instead of the fully qualified name.
Old output
```
clip_boxes_to_image(Tensor boxes, (int, int) size) -> (Tensor):
Expected a value of type 'Tuple[int, int]' for argument 'size' but instead found type 'Tuple[int, int, int]'.
:
at /home/davidriazati/dev/vision/torchvision/models/detection/rpn.py:365:20
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
        batch_idx = torch.arange(num_images, device=device)[:, None]
        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[:self.post_nms_top_n]
            boxes, scores = boxes[keep], scores[keep]
            final_boxes.append(boxes)
            final_scores.append(scores)
'RegionProposalNetwork.filter_proposals' is being compiled since it was called from 'RegionProposalNetwork.forward'
at /home/davidriazati/dev/vision/torchvision/models/detection/rpn.py:446:8
        num_images = len(anchors)
        num_anchors_per_level = [o[0].numel() for o in objectness]
        objectness, pred_bbox_deltas = \
            concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN do not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

        losses = {}
        if self.training:
            assert targets is not None
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets)
            losses = {
'RegionProposalNetwork.forward' is being compiled since it was called from 'MaskRCNN.forward'
at /home/davidriazati/dev/vision/torchvision/models/detection/generalized_rcnn.py:53:8
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        original_image_sizes = [(img.shape[-2], img.shape[-3])  for img in images]

        images, targets = self.transform(images, targets)
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([(0, features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        # TODO: multiple return types??
        # if self.training:
```

New output

```
RuntimeError:

clip_boxes_to_image(Tensor boxes, (int, int) size) -> (Tensor):
Expected a value of type 'Tuple[int, int]' for argument 'size' but instead found type 'Tuple[int, int, int]'.
:
at /home/davidriazati/dev/vision/torchvision/models/detection/rpn.py:365:20
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
'RegionProposalNetwork.filter_proposals' is being compiled since it was called from 'RegionProposalNetwork.forward'
at /home/davidriazati/dev/vision/torchvision/models/detection/rpn.py:446:8
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

        losses = {}
'RegionProposalNetwork.forward' is being compiled since it was called from 'MaskRCNN.forward'
at /home/davidriazati/dev/vision/torchvision/models/detection/generalized_rcnn.py:53:8
        if isinstance(features, torch.Tensor):
            features = OrderedDict([(0, features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26765

Pulled By: driazati

Differential Revision: D17560963

fbshipit-source-id: e463548744b505ca17f0158079b80e08fda47d49

* Fix some return std::move warnings (#27384)

Summary:
clang-tidy was complaining about these
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27384

Pulled By: driazati

Differential Revision: D17767412

fbshipit-source-id: 03e2630790edf3f6bbf9064e754156613032b464

* add function to get nccl version for error messages (#27068)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27068

Adds a function that uses ncclGetVersion from the NCCL API to retrieve the NCCL version and converts it into a readable string; the function is called in NCCL-related error messages to log the NCCL version. Hopefully this will help with debugging NCCL errors.

Test Plan:
Modify C10D_NCCL_CHECK in NCCLUtils.hpp to always error by setting ncclResult_t error = ncclSystemError
force an NCCL error with script test/simulate_nccl_errors.py:
Start master node: python test/simulate_nccl_errors.py localhost 9124 0 2
Start other node: python test/simulate_nccl_errors.py localhost 9124 1 2
On the master node, should see the following error message w/NCCL version:

```
Traceback (most recent call last):
  File "simulate_nccl_errors.py", line 29, in <module>
    process_group.allreduce(torch.rand(10).cuda(rank)).wait()
RuntimeError: NCCL error in: ../torch/lib/c10d/ProcessGroupNCCL.cpp:375, unhandled system error, NCCL version 2.4.8
```

Differential Revision: D17639476

fbshipit-source-id: a2f558ad9e883b6be173cfe758ec56cf140bc1ee

* C++ API parity: Hardtanh

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27038
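
For reference, the Python module being mirrored (illustrative snippet, not part of the diff):

```python
import torch
from torch import nn

m = nn.Hardtanh(min_val=-2.0, max_val=2.0)
m(torch.tensor([-3.0, 0.5, 4.0]))   # -> tensor([-2.0000, 0.5000, 2.0000])
```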

Test Plan: Imported from OSS

Differential Revision: D17682405

Pulled By: pbelevich

fbshipit-source-id: f65e76696e0041c3518f56da94f2e3b800305234

* fix OSX CI build (#27373)

Summary:
fix OSX caffe2 CI build, attempt 1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27373

Differential Revision: D17768461

Pulled By: soumith

fbshipit-source-id: b0a076c07382327730b5d86b8a00f5388c368b5e

* ProcessGroupNCCL should respect timeout passed in to init_process_group. (#27224)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27224

As part of adding error handling to NCCL, we are now able to specify a
timeout for operations using ProcessGroupNCCL. However, this timeout had a
default of 10 seconds and didn't respect the timeout specified in
init_process_group.

In this change, I've ensured we pass the appropriate timeout to
ProcessGroupNCCL.
ghstack-source-id: 91283548
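
A minimal sketch of the setting that is now honored (the rendezvous address is an assumption):

```python
import datetime
import torch.distributed as dist

dist.init_process_group(
    backend="nccl",
    init_method="tcp://localhost:29500",
    rank=0,
    world_size=2,
    timeout=datetime.timedelta(seconds=30),  # now forwarded to ProcessGroupNCCL
)
```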

Test Plan:
Added unit test to verify timeout passed in to init_process_group is
respected.

Differential Revision: D17717992

fbshipit-source-id: c73320187f1f3b2693ba1e177d80646e282d01a2

* Add clip_grad_norm_ to c++ api (#26140)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26140

Per https://github.com/pytorch/pytorch/issues/25883, we want to work
towards C++/Python API parity. This diff adds clip_grad_norm_ to the c++ API to
improve parity.

ghstack-source-id: 91334333

Test Plan: Added a unit test

Differential Revision: D17312367

fbshipit-source-id: 753ba3a4d084d01f3cc8919da3108e67c809ad65

* C++ API parity: LeakyReLU

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27059

Test Plan: Imported from OSS

Differential Revision: D17682407

Pulled By: pbelevich

fbshipit-source-id: 2a4f42e9438799ba8de7282ac7a6fd3ff97ee048

* Some hipify script cleanups (#27375)

Summary:
continue https://github.com/pytorch/pytorch/issues/26363
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27375

Differential Revision: D17764992

Pulled By: bddppq

fbshipit-source-id: ecc06521179677efcedb1d58ceda63df7d63627e

* add some support for the occupancy API on ROCm (#27390)

Summary:
Unfortunately, the HIP function takes uint32_t* instead of int*, so we still need to ifdef for the time being.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27390

Differential Revision: D17768832

Pulled By: bddppq

fbshipit-source-id: c65176660cb0783a04f0a4a064f686818d759589

* Add gfx908 to the list of per-default compiled architectures. (#27388)

Summary:
ROCm 2.8 added preliminary support for gfx908.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27388

Differential Revision: D17767772

Pulled By: bddppq

fbshipit-source-id: 172daf5bb66d3db86a13e287059af4b9b90a7f57

* Change nightly builds version to 1.4.0-SNAPSHOT (#27381)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27381

Changing android nightly builds from master to version 1.4.0-SNAPSHOT, as we also have 1.3.0-SNAPSHOT from the branch v1.3.0

Test Plan: Imported from OSS

Differential Revision: D17773620

Pulled By: IvanKobzarev

fbshipit-source-id: c39a1dbf5e06f79c25367c3bc602cc8ce42cd939

* Pickup proxy parameters for publishing (#27389)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27389

Pick up Gradle proxy parameters (handy for publishing from a devserver) in the Maven publishing Gradle plugin.

Test Plan: Imported from OSS

Differential Revision: D17773548

Pulled By: IvanKobzarev

fbshipit-source-id: 662c0b2835e6cf1e4009da79e27268d4a19c2ceb

* MovingAverage Observer (#27396)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27396

Observer that estimates moving averages of the min and max values per batch. It is better suited for quantization-aware training than min/max observers, which track extremal values across all batches.
ghstack-source-id: 91369018
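A hedged sketch of the moving-average idea described above (illustrative only, not the actual observer implementation): per batch, the min and max are folded into running estimates with an averaging constant instead of keeping global extremes.

```
import torch

class EMAMinMax:
    """Toy moving-average min/max tracker (illustrative only)."""

    def __init__(self, averaging_constant=0.01):
        self.c = averaging_constant
        self.min_val = None
        self.max_val = None

    def observe(self, x):
        batch_min, batch_max = x.min().item(), x.max().item()
        if self.min_val is None:
            self.min_val, self.max_val = batch_min, batch_max
        else:
            # Exponential moving average toward the current batch's extremes.
            self.min_val += self.c * (batch_min - self.min_val)
            self.max_val += self.c * (batch_max - self.max_val)

obs = EMAMinMax()
for _ in range(10):
    obs.observe(torch.randn(64, 32))
```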

Test Plan:
buck test caffe2/test:quantization -- 'test_per_tensor_observers \(test_quantization\.ObserverTest\)' --print-passing-details

buck test caffe2/test:quantization -- 'test_per_channel_observers \(test_quantization\.ObserverTest\)' --print-passing-details

Differential Revision: D17727213

fbshipit-source-id: 024a890bf3dd0bf269d8bfe61f19871d027326f0

* Add methods to write image tensor content to buffer (#27359)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27359

Adding methods to TensorImageUtils:
```
bitmapToFloatBuffer(..., FloatBuffer outBuffer, int outBufferOffset)
imageYUV420CenterCropToFloat32Tensor(..., FloatBuffer outBuffer, int outBufferOffset)
```
To be able to:
 - reuse a FloatBuffer for inference
 - create a batch Tensor (containing several images/bitmaps)

Reusing the FloatBuffer in the example demo app (image classification), the profiler shows fewer memory allocations (previously every run created a new input tensor with a newly allocated FloatBuffer) and about 20 ms saved per run on my Pixel XL.

Known open question:
At the moment every tensor element is written separately by calling `outBuffer.put()`, which is a native call crossing language boundaries.
As an alternative, we could allocate a `float[]` on the Java side, fill it, and put it into `outBuffer` with one call, reducing the number of native calls but increasing memory allocation on the Java side.
Tested locally by eyeballing durations; I did not notice a big difference, so I decided to go with fewer memory allocations.

Will be good to merge into 1.3.0, but if not - demo app can use snapshot dependencies with this change.

PR with integration to demo app:
https://github.com/pytorch/android-demo-app/pull/6

Test Plan: Imported from OSS

Differential Revision: D17758621

Pulled By: IvanKobzarev

fbshipit-source-id: b4f1a068789279002d7ecc0bc680111f781bf980

* add warning to dnnlowp fc if quantization kind is not min_max

Summary:
Print a warning when using DNNLOWP dynamic int8 quantization for FC and activation_quantization_kind != min_max.

The warning will display in the console but not in Bento; we would have to use CAFFE_ENFORCE to alert in Bento.

Test Plan: ran the unit test with buck, forcing DNNLOWP FC with activation_quantization_kind = "l2", and saw the warning printed in the console.

Reviewed By: csummersea

Differential Revision: D17770921

fbshipit-source-id: b6532e4c9a86d74e3db4cb432735505d378a366e

* Add interface/object serialization as module attribute (#26770)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26770

This PR added interface/object serialization as a module attribute, to
allow initializing an object as an interface type during Python
initialization. Because an interface type can be backed by any class object
that implements the interface, if we declare it in
python/module.__init__, we need to collect the runtime types of the
value and serialize them to ensure complete code information.

Test Plan: Imported from OSS

Differential Revision: D17742707

fbshipit-source-id: 7f614ad4f982996d320a0e2dd3515bf47370e730

* Adding docstrings for nnq.functional

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27363

Test Plan: Imported from OSS

Differential Revision: D17758907

Pulled By: zafartahirov

fbshipit-source-id: f560f2726cf51ceebdbf22ebef2d067422340cf2

* Enable RCCL in ROCm build (#27383)

Summary:
continues https://github.com/pytorch/pytorch/pull/23884
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27383

Differential Revision: D17767248

Pulled By: bddppq

fbshipit-source-id: 3a506844ca6f01d7bbe8be5bde0976999e3a2b90

* Add randomFill to test_utils.h

Summary: Add a helper function randomFill to test_utils.h so we can use it in benchmark scripts as well as tests.

Test Plan:
```
buck run mode/opt //tvm/sparse:cblas_bench
```

Reviewed By: yinghai

Differential Revision: D17759193

fbshipit-source-id: e4909b04e83ca9382ab4718855fb63743d028de1

* Use deepcopy inputs for ONNX ort test cases (#27186)

Summary:
Running models with in-place operators will change the values of input tensors.
Deep-copy the input tensors each time to keep the original input tensors intact.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27186
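A minimal sketch of the pattern described above (the model/ONNX Runtime wiring is omitted): deep-copying the inputs before running a model with in-place ops keeps the originals intact for later comparison.

```
import copy
import torch

x = torch.randn(2, 3)
original = x.clone()

x_for_run = copy.deepcopy(x)     # what would be fed to the model / ONNX Runtime
x_for_run.add_(1)                # an in-place op mutates only the copy

assert torch.equal(x, original)  # the original input is unchanged
```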

Differential Revision: D17776598

Pulled By: jerryzh168

fbshipit-source-id: d4808a11185a9ab0d782a62d7d708dfe7e94559c

* Remove dependency on six from dist_autograd_test.py

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27369

Test Plan: Imported from OSS

Differential Revision: D17763104

Pulled By: mrshenli

fbshipit-source-id: dd146809686e7720f2b77012eebb6aed72851556

* Docstring fix (#27225)

Summary:
Correcting docstring for `add_image_with_boxes` method. Fixed spelling mistake.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27225

Differential Revision: D17776604

Pulled By: jerryzh168

fbshipit-source-id: 45f69643ec3b58c46b9fb67411c42a6d09b7290e

* Tweak docs on building docs

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27364

Differential Revision: D17777402

Pulled By: dzhulgakov

fbshipit-source-id: 304c678e5c80d7f8c779d65c11f9bf1b0facdb52

* Upgrade to ROCm 2.9 (#27417)

Summary:
New docker images built with tag 325: https://ci.pytorch.org/jenkins/job/caffe2-docker-trigger/325

Related ossci-job-dsl commits:
https://github.com/pytorch/ossci-job-dsl/commit/a00a76f927944aed961a3bbbc4f17aff0fc30d71
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27417

Differential Revision: D17777517

Pulled By: bddppq

fbshipit-source-id: a6b8cb86b37f537d402f6d2c7d28ad28a6a5a317

* enable rocTX API (#27416)

Summary:
ROCm 2.9 brings support for the rocTX API through rocTracer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27416

Differential Revision: D17777480

Pulled By: bddppq

fbshipit-source-id: 6bce9b54c94e5b4c5787570d2b85736882bd23a7

* C++ API parity: LogSigmoid

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27060

Test Plan: Imported from OSS

Differential Revision: D17682404

Pulled By: pbelevich

fbshipit-source-id: d60d64cd4caf1f56a2e05c516f91321d46ec9624

* Remove Tensor.h, TensorMethods.h from src/core. (#27086)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27086

This is a major source of merge conflicts, and AFAICT isn't necessary anymore (it may have been necessary for some mobile build stuff in the past).

This is a commandeer of #25031

Test Plan: Imported from OSS

Reviewed By: ljk53

Differential Revision: D17687345

Pulled By: ezyang

fbshipit-source-id: bf6131af835ed1f9e3c10699c81d4454a240445f

* Remove outdated note in cholesky_solve and triangular_solve doc strings (#26989)

Summary:
We do support inputs with dim > 2 in _out variants
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26989

Differential Revision: D17785632

Pulled By: soumith

fbshipit-source-id: d42ba7ca9c225ad1a26ff3b410d0c5c08eaed001

* Disable tsan for test_multiprocessing. (#27410)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27410

Similar to https://github.com/pytorch/pytorch/pull/25005, TSAN is not
safe to use in a multi-threaded program with fork and can cause deadlocks. As a
result, disabling this test for TSAN.
ghstack-source-id: 91393545

Test Plan: buildbot

Differential Revision: D17775141

fbshipit-source-id: 109b8095240ad43ee4a6380f70b9efca863c0a4a

* Unfold export (#24970)

Summary:
ONNX export for Unfold in symbolic opset9 + op and ORT tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24970

Reviewed By: hl475

Differential Revision: D17495106

Pulled By: houseroad

fbshipit-source-id: fcd179a1213c0f219628f25c09e66fcfe4c5df50

* Reduce special casing around 'training' (#27109)

Summary:
Most of this was old cruft left over from special handling of `training` before we had a `bool` type. This makes all modules have a `training` attribute that is true by default and removes all other special handling.

Fixes #26884
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27109

Pulled By: driazati

Differential Revision: D17728129

fbshipit-source-id: 8ddc9fbb07a953dd05529538bfdd01ed88b5cb57

* Put metrics back to torch.utils.tensorboard similar we have in TensorboardX

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27252

Test Plan: Check metrics in the Scuba table: https://fburl.com/scuba/k5x8yosj

Reviewed By: sanekmelnikov

Differential Revision: D17723414

fbshipit-source-id: 64d42e0b4582f635d38f38feb2b2a6c4826f2065

* Automatic update of fbcode/onnx to 2891e1459745933f4bba9a8cb3371cf3c9eb1d16 (#27474)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27474

Previous import was 034921bd574cc84906b7996c07873454b7dd4135

Included changes:
- **[2891e145](https://github.com/onnx/onnx/commit/2891e145)**: Fix Unique unit test (#2381) <Scott McKay>
- **[25cf73e5](https://github.com/onnx/onnx/commit/25cf73e5)**: update shapeInference h file link (#2369) <prcvih>
- **[e3074bc0](https://github.com/onnx/onnx/commit/e3074bc0)**: modify file path (#2378) <prcvih>
- **[9058d3a4](https://github.com/onnx/onnx/commit/9058d3a4)**: Incrementing version number to 1.6.0 (#2353) (#2385) <Kevin Chen>
- **[c963586d](https://github.com/onnx/onnx/commit/c963586d)**: Remove typing packages from test requirements (#2375) <Aiken Cairncross>

Test Plan: ci

Reviewed By: bddppq

Differential Revision: D17791527

fbshipit-source-id: 23ad5abe313cd4e4eedcbe7794b98450b3b7d3bc

* Fixed Select symbolic to export slice when index = negative one (#25273)

Summary:
Exporting torch.select when the index is negative one (x[:,-1]) was broken. This PR has the fix in the symbolic function for select.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25273
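An illustrative repro of the pattern being fixed (module name and shapes are made up): exporting a model that selects the last column via a negative index.

```
import io
import torch

class LastColumn(torch.nn.Module):
    def forward(self, x):
        return x[:, -1]  # select with index = -1, the case fixed here

buf = io.BytesIO()
torch.onnx.export(LastColumn(), torch.randn(4, 5), buf)
```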

Reviewed By: hl475

Differential Revision: D17159707

Pulled By: houseroad

fbshipit-source-id: 2c3b275421082758f1b63c1c9b6e578f03ca9f76

* Avoid variable shadowing in ``::at::philox_engine::single_round()`` (#27486)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27486

Rename `key` argument of `single_round` method to `in_key`

Test Plan: CI

Reviewed By: stepancheg, soumith

Differential Revision: D17782904

fbshipit-source-id: 6feae55c407f39d41db099b013dcbd3990768603

* Refactor python_android test to separate Android-specific components (#27453)

Summary:
All of the test cases move into a base class that is extended by the
instrumentation test and a new "HostTests" class that can be run in
normal Java.  (Some changes to the build script and dependencies are
required before the host test can actually run.)

ghstack-source-id: fe1165b513241b92c5f4a81447f5e184b3bfc75e
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27453

Test Plan: Imported from OSS

Reviewed By: IvanKobzarev

Differential Revision: D17800410

fbshipit-source-id: 1184f0caebdfa219f4ccd1464c67826ac0220181

* Various cleanups to pytorch_android API (#27454)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27454

See detailed discussion at
https://github.com/pytorch/pytorch/issues/27350

Test Plan: Imported from OSS

Reviewed By: IvanKobzarev

Differential Revision: D17800480

Pulled By: dreiss

fbshipit-source-id: bf174e8b16231b89be771de0fa54c41e864a3eb0

* Clean up JavaDoc comments in pytorch_android

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27455

Test Plan: Imported from OSS

Differential Revision: D17800658

Pulled By: dreiss

fbshipit-source-id: dbd01d9fa5ac82c50daf54c2869dc18be233d8dd

* FunctionEventAvg implements __iadd__ interface (#27498)

Summary:
Resolving issue https://github.com/pytorch/pytorch/issues/26433 by making FunctionEventAvg implement the `__iadd__` interface again, like it used to.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27498

Differential Revision: D17801918

Pulled By: ezyang

fbshipit-source-id: 0597059c903ac168ed64a05ac1decff3ffd14f06

* Move hipify to torch/utils to bundle them into torch package (#27425)

Summary:
Similar to https://github.com/pytorch/pytorch/pull/27418 but try to put it under "torch" namespace
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27425

Differential Revision: D17779490

Pulled By: bddppq

fbshipit-source-id: 688338d143509b37dfc110df17af3331db48a42b

* Ensure NCCL error handling code is disabled for NCCL versions < 2.4 (#27124)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27124

ncclCommAbort() and ncclGetAsyncError() were two APIs added in NCCL
2.4 to detect errors in NCCL communicators. These were used as part of
ProcesGroupNCCL and we also enforced that only NCCL versions 2.4+ were
supported. Although, there is still legitimate use for older NCCL versions and
hence we should still support those.

For that purpose, in this change I've ensured we disable NCCL error checking
for versions < 2.4.
ghstack-source-id: 91452959

Test Plan:
1) Test with 2.4.8
2) Test with 2.2.13
3) unit tests.

Differential Revision: D17178988

fbshipit-source-id: 5dc44b5f7b4b00466c67fd452315f1d4f5c47698

* #include <stdexcept> into flat_hash_map.h (#27478)

Summary:
Fixing https://github.com/pytorch/pytorch/issues/27266

In general we should not rely on transitively included headers; we should explicitly include all headers whose members are used in the source file.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27478

Differential Revision: D17799522

Pulled By: pbelevich

fbshipit-source-id: 5818394a212c947cfac3a6cf042af9ebb8b9d9a0

* Fix broken name mangling

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27511

Test Plan: Imported from OSS

Differential Revision: D17801185

Pulled By: jamesr66a

fbshipit-source-id: 3eaa9542a445c9401f3f96e11138ec09b0d8350a

* Updating submodules

Summary:
GitHub commits:

https://github.com/facebook/fbthrift/commit/e80ecd1d63c956ed34b257fbd1aaef73ef8eb781
https://github.com/facebook/proxygen/commit/6c7a36b1b3f2825fd30ba00c708ec5ceaa5db760
https://github.com/facebookincubator/mvfst/commit/875046204325f9bd8cc5343b98a8fa4b99187a3c
https://github.com/facebook/proxygen/commit/442d7def679c297427f5d0b679685db92fe3d28c
https://github.com/facebook/wangle/commit/c138dc3d2c0c4f4f68ab4931e44b87a6becb194c
https://github.com/facebookincubator/fizz/commit/3833f10989711256704260a01e0c9f7d1c33e468
https://github.com/facebookincubator/katran/commit/6fc473d5304985aa31d351c6305904e80af4b614
https://github.com/pytorch/fbgemm/commit/82d259dade58e53775a534f88b7b48e760f09a64

Test Plan: n/a

Reviewed By: 2d2d2d2d2d

fbshipit-source-id: 7834a4a8620d0ab9b60060e0abadfba457fb2890

* Revert D17159707: [pytorch][PR] [ONNX] Fixed Select symbolic to export slice when index = negative one

Test Plan: revert-hammer

Differential Revision:
D17159707

Original commit changeset: 2c3b27542108

fbshipit-source-id: accce910abdbe13270d0f592810a48b1dabe4b01

* Roll master to 1.4.0 (#27374)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27374

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D17809770

Pulled By: ezyang

fbshipit-source-id: 75bd97426494a7bbbf08f9bce7563d35871443d8

* Exponential decay of the weight of task loss (#27508)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27508

Implemented a simple exponential decay of the weight of the lr loss function, with a lower bound.
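A hypothetical sketch of such a schedule (parameter names and the exact decay form are assumptions, not the DPER implementation): the weight decays geometrically per step and is clamped at a lower bound.

```
def task_loss_weight(step, initial_weight=1.0, decay_rate=0.99, min_weight=0.1):
    # Geometric decay per step, never dropping below min_weight.
    return max(min_weight, initial_weight * decay_rate ** step)

# The weight shrinks each step and eventually saturates at the lower bound.
weights = [task_loss_weight(s) for s in (0, 100, 300, 1000)]
```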

Test Plan:
buck test //caffe2/caffe2/fb/dper/layer_models/tests:mtml_test -- test_task_weight_decay
https://our.intern.facebook.com/intern/testinfra/testrun/3377699729136308

canary: f140103452

Reviewed By: chenshouyuan

Differential Revision: D17524101

fbshipit-source-id: 9a653e21a4ecb74dfc4ac949c9e3388f36ef3a20

* docstring only formatting changes: quantize.py, fake_quantize.py, observer.…
pritamdamania87 pushed a commit that referenced this issue Oct 9, 2019
Pull Request resolved: #27022

This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine executes the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose an `execute_with_graph_task` API, which allows the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose an `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait() receives a message of type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is in line with what most
Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.
ghstack-source-id: 91586182

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)
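A minimal usage sketch of the backward pass described above, assuming the torch.distributed.rpc / torch.distributed.autograd module layout that later shipped; exact function names (and whether context_id must be passed explicitly) may differ from the API in this commit.

```
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# Assumes RPC was already initialized on workers named "worker0"/"worker1",
# e.g. rpc.init_rpc("worker0", rank=0, world_size=2) on this node.
with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)

    # Forward: the remote add records send/recv functions in the autograd context.
    loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()

    # Backward: distributed backward pass starting from `loss` (FAST mode).
    dist_autograd.backward(context_id, [loss])

    # Gradients accumulate per-context rather than in each Tensor's .grad field.
    grads = dist_autograd.get_gradients(context_id)
```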
pritamdamania87 pushed a commit that referenced this issue Oct 9, 2019
…pass implementation."


[test all] This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do. This was mostly done to ensure Future.wait() propagates errors correctly on the backward pass.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)

[ghstack-poisoned]
pritamdamania87 pushed a commit that referenced this issue Oct 9, 2019
…ion."


[test all] This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do. This was mostly done to ensure Future.wait() propagates errors correctly on the backward pass.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)

[ghstack-poisoned]
pritamdamania87 pushed a commit that referenced this issue Oct 9, 2019
…pass implementation."


[test all] This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do. This was mostly done to ensure Future.wait() propagates errors correctly on the backward pass.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)

[ghstack-poisoned]
pritamdamania87 pushed a commit that referenced this issue Oct 9, 2019
…ion."


[test all] This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do. This was mostly done to ensure Future.wait() propagates errors correctly on the backward pass.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)

[ghstack-poisoned]
pritamdamania87 pushed a commit that referenced this issue Oct 9, 2019
…pass implementation."


[test all] This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do. This was mostly done to ensure Future.wait() propagates errors correctly on the backward pass.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)

[ghstack-poisoned]
pritamdamania87 pushed a commit that referenced this issue Oct 9, 2019
…ion."


[test all] This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do. This was mostly done to ensure Future.wait() propagates errors correctly on the backward pass.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)

[ghstack-poisoned]
pritamdamania87 pushed a commit that referenced this issue Oct 10, 2019
…pass implementation."


[test all] This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do. This was mostly done to ensure Future.wait() propagates errors correctly on the backward pass.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)

[ghstack-poisoned]
pritamdamania87 pushed a commit that referenced this issue Oct 10, 2019
…ion."


[test all] This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do. This was mostly done to ensure Future.wait() propagates errors correctly on the backward pass.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)

[ghstack-poisoned]
pritamdamania87 pushed a commit that referenced this issue Oct 10, 2019
Pull Request resolved: #27022

This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.
ghstack-source-id: 91650306

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)
pritamdamania87 pushed a commit that referenced this issue Oct 11, 2019
…pass implementation."


[test all] This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do. This was mostly done to ensure Future.wait() propagates errors correctly on the backward pass.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.

Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this issue Oct 12, 2019
Summary:
Pull Request resolved: #27022

This change implements the "FAST" mode distributed autograd backward
pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine executes the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of computing dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete, and also waits for all the 'Futures' (stored in step 4) for the respective
RPCs to finish.

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose an `execute_with_graph_task` API, which allows the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose an `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this, a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait() receives a message of type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is in line with what most
Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.
ghstack-source-id: 91794926

Test Plan: unit tests.

Differential Revision: D17652615

fbshipit-source-id: 96f65c52adb2706ee29f4b49e1655afaa0a3bec3
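For concreteness, here is a minimal sketch of how this backward pass and the `get_gradients(context_id)` API are driven from Python, using the `torch.distributed.autograd` and `torch.distributed.rpc` modules as they eventually shipped; the worker names, the `init_rpc` setup, and the loss function are illustrative assumptions rather than part of the change itself.

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def loss_on_worker1(t):
    # runs remotely on "worker1"; its 'send'/'recv' pair is recorded in the context
    return (t * t).sum()

# on "worker0", after rpc.init_rpc("worker0", rank=0, world_size=2)
t = torch.rand(3, 3, requires_grad=True)
with dist_autograd.context() as context_id:
    loss = rpc.rpc_sync("worker1", loss_on_worker1, args=(t,))
    # FAST mode: every 'send' function recorded in this context participates in backward
    dist_autograd.backward(context_id, [loss])
    grads = dist_autograd.get_gradients(context_id)  # map from Tensor to its gradient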
@pietern pietern removed the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 5, 2019
facebook-github-bot pushed a commit that referenced this issue Nov 14, 2019
Summary:
Closes #28983. Documentation for `torch.distributed.rpc` and `torch.distributed.autograd` modules. Also fixes/tidies up some of the docstrings in rpc/autograd, and moves some functions to be private so they don't show up in the documentation.

Note: Much of the text describing/explaining the RPC/RRef layers is taken from the following RFCs: #23110, #26759
Pull Request resolved: #29276

Differential Revision: D18478754

Pulled By: rohan-varma

fbshipit-source-id: e9a7089baf5275304e5408d319eb9bf98e53fff8
pdlive215 pushed a commit to pdlive215/pytorch that referenced this issue Nov 27, 2019
…rch#25527)

Summary:
Pull Request resolved: pytorch#25527

Master GH issue: pytorch#23110.

This change builds upon pytorch#24876 and
provides all the autograd hooks needed for a forward pass with distributed RPC
for builtin operators. This change does not address distributed RPC for Python
UDFs; that will be addressed in follow-up PRs.

Summary of changes:
1. Attach send autograd functions when a request is sent from the client and
a response is sent from the server.
2. Attach receive autograd functions when a request is received on the server
and a response is received on the client.
3. Generate a globally unique autograd_message_id for each send/recv autograd
function pair to uniquely identify them.
ghstack-source-id: 91240466

Test Plan: unit tests.

Differential Revision: D17148077

fbshipit-source-id: 192d8a3f552ed7cc939f55dcca332965c9bd3233
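As a rough illustration of what these hooks enable, the sketch below invokes a builtin operator over RPC inside a distributed autograd context, which attaches the send/recv pair described above; the worker names and `init_rpc` setup are assumptions for the example only.

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# on "worker0", after rpc.init_rpc("worker0", rank=0, world_size=2)
t1 = torch.rand(3, 3, requires_grad=True)
t2 = torch.rand(3, 3, requires_grad=True)
with dist_autograd.context() as context_id:
    # a 'send' function is attached on this worker and a matching 'recv' function
    # on "worker1", linked by a globally unique autograd_message_id
    res = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))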
pdlive215 pushed a commit to pdlive215/pytorch that referenced this issue Nov 27, 2019
Summary:
Pull Request resolved: pytorch#25499

See pytorch#23110 for model parallel design details, and pytorch#26759 for the RRef
protocol. This commit adds support for using RRefs as Python UDF arguments
and return values. RRefs can now be shared from owner to user, from user to
owner, or from user to user.

Limitations:
1. No implicit type conversion yet. (pytorch#27099)
2. No failure handling and retry. (pytorch#26116)
3. UDF is not yet blocked until all RRefs are confirmed. (pytorch#27098)
4. Internal RRef control messages are not idempotent yet. (pytorch#26116)
5. Cannot delete RRefs correctly when there are circular dependencies. (pytorch#27096)

Main changes:

1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations.
2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages.
3. New message request handling code is added to `functions.cpp`, and message format is added in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`.
4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork related control messages will be sent during RRef pickling/unpickling procedure.
5. Updated `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs.
6. RRef context (reference count, etc.) is tracked in `rref_context.h` and `rref_context.cpp`.

Test Plan:
Imported from OSS

buck test mode/dev-nosan //caffe2/test:rpc_fork

Differential Revision: D17184146

Pulled By: mrshenli

fbshipit-source-id: a3a268efc087ac1ef489136ab957080382629265
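A small sketch of the owner-to-user and user-to-user sharing described above, written against the Python RPC API; the worker names and the `sum_on_callee` helper are hypothetical.

import torch
import torch.distributed.rpc as rpc

def sum_on_callee(rref):
    # runs on the callee, which fetches the value from the RRef's owner ("worker1")
    return rref.to_here().sum()

# on "worker0", after rpc.init_rpc("worker0", rank=0, world_size=3)
# rpc.remote returns a user RRef whose value lives on "worker1"
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2, 2), 1))
# passing the RRef to "worker2" shares it user-to-user; pickling it triggers the
# RREF_FORK_REQUEST / RREF_CHILD_ACCEPT control messages listed above
total = rpc.rpc_sync("worker2", sum_on_callee, args=(rref,))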
@ZiyiZhu
ZiyiZhu commented Apr 9, 2020

Hi,

I am studying the RPC framework for distributed model-parallel training on multiple machines. I saw the documents and examples for the RPC framework at this link:
https://pytorch.org/docs/stable/rpc.html#distributed-rpc-framework

However, it only shows examples with two workers, worker 0 and worker 1, where the communication is always worker 0 -> worker 1 -> worker 0.

Is there an example or document that describes more complicated situations like the ones you showed here? Another example could be worker 0 -> worker 1 -> worker 2, then stop or go back to worker 0.
I wonder if you could let me know whether there is a document covering more workers, or something like the example above:

# all functions should be defined on all workers

def worker0_func(c2: Tensor) -> Tensor:
    g = torch.rand(2, 2)
    h = g + c2
    return h

def worker1_func_top() -> Tensor:
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = a + b
    return c

def worker1_func_bottom(c: Tensor, e1: Tensor) -> Tensor:
    f = c + e1
    return f

def worker2_func(c1: Tensor) -> Tensor:
    d = torch.rand(2, 2)
    e = c1 + d
    return e

# on Worker3

c_ref = torch.remote(worker1_func_top, on="Worker1")
h1 = torch.rpc(worker0_func, c_ref, on="Worker0")
e_ref = torch.remote(worker2_func, c_ref, on="Worker2")
f1 = torch.rpc(worker1_func_bottom, c_ref, e_ref, on="Worker1")
i = h1 + f1
i.sum().backward()

but written with the finished PyTorch code?

Thank you,
Ziyi
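
For reference, a minimal sketch of how the example above might look with the `torch.distributed.rpc` and `torch.distributed.autograd` APIs that eventually shipped; the worker names, ranks, `init_rpc` setup, and the `requires_grad=True` flags are assumptions added for illustration.

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# every function below must be importable on every worker
def worker1_func_top():
    a = torch.rand(2, 2, requires_grad=True)
    b = torch.rand(2, 2, requires_grad=True)
    return a + b

def worker0_func(c_rref):
    g = torch.rand(2, 2, requires_grad=True)
    return g + c_rref.to_here()

def worker2_func(c_rref):
    d = torch.rand(2, 2, requires_grad=True)
    return d + c_rref.to_here()

def worker1_func_bottom(c_rref, e_rref):
    return c_rref.to_here() + e_rref.to_here()

# on "worker3", after rpc.init_rpc("worker3", rank=3, world_size=4)
with dist_autograd.context() as context_id:
    c_ref = rpc.remote("worker1", worker1_func_top)
    h1 = rpc.rpc_sync("worker0", worker0_func, args=(c_ref,))
    e_ref = rpc.remote("worker2", worker2_func, args=(c_ref,))
    f1 = rpc.rpc_sync("worker1", worker1_func_bottom, args=(c_ref, e_ref))
    i = h1 + f1
    dist_autograd.backward(context_id, [i.sum()])
    grads = dist_autograd.get_gradients(context_id)

Note that gradients for tensors created on the other workers are accumulated in their own distributed autograd contexts on those workers, not returned to "worker3".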
