
RFC: Distributed Keras training APIs with parameter servers #334


Conversation

@rchao (Member) commented · Nov 23, 2020

This RFC will be open for comment until Friday, December 7th, 2020.

Distributed Keras training APIs with parameter servers

Status: Under review
RFC #: 334
Author(s): Rick Chao, Tom O'Malley, Zhenyu Tan, Yuefeng Zhou (Google)
Sponsor(s): Francois Chollet, Priya Gupta (Google)
Updated: 2020-11-21

Goals

  • Parameter server training support for Keras compile/fit style training API
  • Minimal code changes across usage with other strategies
  • Minimal performance implications
rchao added 2 commits Nov 23, 2020
@fchollet (Member) left a comment

Thank you for the detailed and well-written RFC!


## Background

With the recent release of TF2 parameter server training support ([design doc](https://github.com/tensorflow/community/blob/master/rfcs/20200306-single-client-parameter-server.md)) ([API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/parameter_server_strategy_v2.py)) ([tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training)), custom training loop (CTL) users have started using the `ParameterServerStrategy` and `ClusterCoordinator` APIs for parameter-server-style distributed training. `ParameterServerStrategy` provides variable placement and APIs for defining computation, while `ClusterCoordinator` provides APIs for dataset creation, asynchronous function scheduling, and remote execution. The asynchronicity brought by `ClusterCoordinator` provides scalability and training fault tolerance, but also brings implications such as the need for remote resource creation.
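To ground these terms, a minimal CTL sketch of the two APIs might look as follows (cluster setup, model, optimizer, and `dataset_fn` are elided placeholders; the calls shown are those from the linked tutorial):

```python
import tensorflow as tf

cluster_resolver = ...  # e.g. a TFConfigClusterResolver describing ps/worker tasks
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

with strategy.scope():
  model = ...      # variables are placed on parameter servers
  optimizer = ...

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

@tf.function
def step_fn(iterator):
  def replica_fn(batch):
    ...  # forward pass, gradient computation, optimizer update
  return strategy.run(replica_fn, args=(next(iterator),))

per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
for _ in range(num_steps):
  # `schedule` returns a RemoteValue immediately; execution is asynchronous.
  loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
coordinator.join()  # block until all scheduled functions have finished
```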

@fchollet (Member) · Nov 23, 2020

Please start by defining the specialized terms you will use, e.g. coordinator. Make a list of the different terms and their definition (this will also enable us to check whether we have a unified terminology)

@rchao (Author, Member) · Nov 24, 2020

Sounds good. Added a "Glossary" section.

model = ...  # Build a Keras model
model.compile(optimizer=..., loss=...)  # `ClusterCoordinator` is created

def dataset_fn():
  ...  # Make use of `preproc_stage` for transformation

@fchollet (Member) · Nov 23, 2020

Add return statement

@rchao (Author, Member) · Nov 24, 2020

Added.

model.compile(optimizer=..., loss=...)  # `ClusterCoordinator` is created

def dataset_fn():
  ...  # Make use of `preproc_stage` for transformation

history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])

@fchollet (Member) · Nov 23, 2020

We should generally ask users to return finite datasets and avoid steps_per_epoch, no? Or it is different for PS?

@yuefengz (Member) · Nov 24, 2020

Since ClusterCoordinator.schedule can schedule a function to any worker, it is less error-prone to ask users to create the same dataset on every worker, repeated but shuffled differently. In this case, steps_per_epoch is needed.

It is currently difficult to achieve a visitation guarantee with ClusterCoordinator.schedule, but this may become easier with tf.data service in the future.
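Concretely, the recommendation amounts to a dataset_fn along these lines (a sketch; `features`, `labels`, and the batch size are placeholders):

```python
import tensorflow as tf

def dataset_fn():
  # Every worker builds the same dataset; shuffling without a fixed seed
  # gives each worker a different order, and `repeat()` makes the dataset
  # infinite so no worker hits OutOfRangeError mid-training.
  ds = tf.data.Dataset.from_tensor_slices((features, labels))
  ds = ds.shuffle(buffer_size=1024)
  ds = ds.repeat()
  return ds.batch(32).prefetch(tf.data.experimental.AUTOTUNE)
```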

@omalleyt12 · Nov 24, 2020

Does this mean we should require the user to return an infinite dataset in the PS case? It sounds like the Dataset might hit an OutOfRange error otherwise. In that case, we should throw an error upfront if the user passes a finite dataset.

@yuefengz (Member) · Nov 24, 2020

Yes, we have explicitly documented the requirement: https://www.tensorflow.org/tutorials/distribute/parameter_server_training#more_about_dataset_creation

There is no error checking in ParameterServerStrategy right now, but it would be nice to add it.

@rchao (Author, Member) · Nov 24, 2020

With the current design of the parameter server training APIs, error handling allows OutOfRangeError to be reported to the user, but the caveat is that it can mean only one worker has exhausted its dataset while others still have data left to process. That behavior will likely differ from what users expect, especially when the dataset is sharded, so the recommendation is to provide an infinite dataset, shuffled with a different seed on each worker.

@omalleyt12 · Nov 24, 2020

SGTM. IMO, in the DataAdapter class for the DatasetFactory, we should then add an error explaining this if the dataset isn't infinite.
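The upfront check could look roughly like this (a sketch of the proposed guard, not the actual DataAdapter code; the function name is an assumption):

```python
import tensorflow as tf

def ensure_infinite_dataset(dataset):
  # Reject finite datasets early so users get an actionable error instead
  # of a confusing OutOfRangeError from a single exhausted worker.
  if (tf.data.experimental.cardinality(dataset) !=
      tf.data.experimental.INFINITE_CARDINALITY):
    raise ValueError(
        'Parameter server training with `Model.fit` requires an infinite '
        'dataset. Apply `.repeat()` to your dataset and control epoch '
        'length with `steps_per_epoch`.')
```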

@rchao (Author, Member) · Dec 2, 2020

Yes - added this in DataAdapter class part.

def dataset_fn():
  ...  # Make use of `preproc_stage` for transformation

history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
logging.info("result: %r", history)

@fchollet (Member) · Nov 23, 2020

It would also be useful to cover a few crucial details in this snippet: TensorBoard remote monitoring, fault tolerance, and saving artifacts

@tanzhenyu (Contributor) · Nov 25, 2020

@fchollet It's important that we discuss what the preferred solution is here, i.e., dataset or dataset_fn.



with a dataset:

@fchollet (Member) · Nov 23, 2020

"...instance" (both examples use a dataset)

@fchollet (Member) · Nov 23, 2020

In this case, the dataset instance is serialized and deserialized on each worker?

@omalleyt12 · Nov 24, 2020

Yeah, from what I understand though, there are currently issues with trying to serialize and deserialize the Dataset onto each worker, which is why IMO we'll probably have to start with only supporting a "Dataset factory" input type for PS

@rchao (Author, Member) · Nov 24, 2020

Done, and I'm thinking the same @omalleyt12

@rchao (Author, Member) · Nov 24, 2020

The dataset instance path would require the library to serialize and deserialize the dataset on workers.



This logic can be handled by the `CallbackList` object, which already handles converting `tf.Tensor` logs to `NumPy`. However, obtaining these logs requires us to sync the workers, which would cause a slowdown if done every batch. Currently we plan to sync the workers only once per epoch, at the end of the epoch.

@fchollet (Member) · Nov 23, 2020

We should not have callbacks include PS-specific logic. The task of synchronizing (and deciding when it's appropriate) should be done in the Model.

Note that it is necessary for us to provide a way to synchronize and run callbacks at higher frequency than once per epoch. This is critical. We could use steps_per_execution to configure this (with a high default value for the PS use case).

@yuefengz (Member) · Nov 24, 2020

It is desirable that some callbacks, like checkpointing callbacks, be invoked in the middle of an epoch. But increasing steps_per_execution to a large value has disadvantages: the steps inside a single function call cannot be distributed across workers for load-balancing purposes, and a worker failure (which interrupts a function execution) forces all steps_per_execution steps to be retried.

Also, checkpointing doesn't necessarily need synchronization in asynchronous training. IMO it would be useful for Keras to support a time-based callback for checkpointing that doesn't need synchronization. Checkpointing, especially in the middle of an epoch, is mainly for failure recovery, so time-based callbacks would let users know that at most X minutes of training would be lost on a PS failure.
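The time-based trigger described here could be sketched independently of Keras (a hypothetical helper; `save_fn` and the interval are assumptions, not part of the RFC):

```python
import time

class TimeBasedCheckpointer:
  """Sketch: fire a checkpoint whenever `interval_secs` have elapsed,
  with no synchronization across workers required."""

  def __init__(self, save_fn, interval_secs=600):
    self._save_fn = save_fn          # e.g. tf.train.Checkpoint.save on the coordinator
    self._interval = interval_secs
    self._last_save = time.time()

  def maybe_save(self):
    # Called periodically (e.g. from the coordinator's main loop); saves
    # only when the interval has elapsed and reports whether it did.
    now = time.time()
    if now - self._last_save >= self._interval:
      self._save_fn()
      self._last_save = now
      return True
    return False
```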

@omalleyt12 · Nov 24, 2020

> We should not have callbacks include PS-specific logic

Agree. Really the only change needed is in tf_utils.to_numpy_or_python_type, which CallbackList uses when it detects that it needs to convert Tensor logs into NumPy for a batch. A user would write Callbacks without having to know anything about this; they'll still see NumPy data.

The change would roughly be:

def to_numpy_or_python_type(data):
  # Sync workers and get NumPy values
  if isinstance(data, RemoteValue):
    get_strategy().join()
    return data.fetch()
  else:
    # Current logic stays the same for other strategies.
    ...

@rchao (Author, Member) · Nov 24, 2020

Re: "The task of synchronizing (and deciding when it's appropriate) should be done in the Model": one option is to have logs resolved in model.fit, after the steps are done and before the epoch_end callbacks are called. Concretely, in the current code, we do:

      for epoch, iterator in data_handler.enumerate_epochs():
        ...
        with data_handler.catch_stop_iteration():  # This will sync at end of steps
          for step in data_handler.steps():
              tmp_logs = self.train_function(iterator)
              logs = tmp_logs
              ...
        logs = data_handler.resolve_logs(logs)  # convert remote to concrete
        if logs is None:
          raise ValueError('Expect x to be a non-empty array or dataset.')
        epoch_logs = copy.copy(logs)

        # Run validation.
        ...

        callbacks.on_epoch_end(epoch, epoch_logs)  # epoch_logs are concrete values

With this, we would just need to provide an implementation for resolve_logs. In the case of ClusterCoordinatorDataHandler, it would be logs.fetch(), whereas the default is a no-op. I believe this could avoid a change in tf_utils.to_numpy_or_python_type.

@omalleyt12 · Nov 24, 2020

I think it's OK to have the logic in tf_utils.to_numpy_or_python_type (this utility is basically already "the thing that takes values returned by Model methods and converts them to what the user should see").

IIUC, Francois's concern was that (please correct me if I'm wrong Francois) users who write Callbacks shouldn't have to be aware of any of this, shouldn't have to know what a RemoteValue is, etc.

As long as we achieve that, either data_handler.resolve_logs or tf_utils.to_numpy_or_python_type seems fine to me, with a slight preference for tf_utils because it consolidates all of our logs munging into one place

@rchao (Author, Member) · Dec 2, 2020

Sounds good Tom - and thanks for updating the callback section which reflects the current consensus.


###### Option 1: Detect and support batch-level callbacks

We have a mechanism in `CallbackList` to detect when batch-level callbacks are being used. This mechanism will only block asynchronous TPU execution when batch-level callbacks require it. We can use this mechanism to also only block asynchronous PSStrategy execution when batch-level callbacks require it. This would allow us to support batch-level callbacks, without paying a performance penalty in the case where they are not used.

@fchollet (Member) · Nov 23, 2020

This should be done in the Model. The model should initiate the sync every steps_per_execution steps, and call the callbacks at that time with a dict of plain values.

Callbacks must be agnostic to PS.

@omalleyt12 · Nov 24, 2020

In addition to the steps_per_execution logic in Model, the CallbackList object has this mechanism so that it only blocks when absolutely necessary. This saves a lot of time on TPUs and is complementary to steps_per_execution.

But we can re-use this logic while still fully supporting batch-level Callbacks. This logic is just so that, for the users who don't pass batch-level Callbacks, we don't end up unnecessarily slowing down training

For users who pass batch-level Callbacks, the CallbackList object can detect this and pass the logs every batch, with the modifications to tf_utils.to_numpy_or_python_type noted above.


##### Batch-level callbacks

Some users may want to use batch-level `Callback`s. When users use `steps_per_execution=N`, the `Callback`s will only execute every `N` steps, and so batch-level callbacks might not be prohibitively slow for large `N`. However, in most cases, batch-level callbacks will cause a significant slowdown and are likely to be added only in error. We have a few options for handling batch-level Callbacks.

@fchollet (Member) · Nov 23, 2020

I disagree, batch-level callbacks are useful and important. We just need to make sure that they're called at an appropriate frequency (especially by default)



###### Option 2: Detect, warn, and support batch-level callbacks

@fchollet (Member) · Nov 23, 2020

No to this and the remaining 2 options below.


#### model.evaluate and model.predict

Initially, we aim to have `evaluate` and `predict` be carried out only on the coordinator. That is, they do not involve distribution via a `ClusterCoordinator`.

@fchollet (Member) · Nov 23, 2020

Sounds good.

@yuefengz (Member) left a comment

Overall I think we need to distinguish clearly between what already exists and what is being proposed here.





##### Option 2: Have an attribute in `ParameterServerStrategy` that holds the `ClusterCoordinator`


rchao and others added 3 commits Nov 24, 2020
Callback changes
model.compile(metrics=tf.keras.metrics.SparseCategoricalAccuracy(...))
data = tf.data.Dataset.from_tensor_slices(...)
SidecarEvaluator(

@yuefengz (Member) · Nov 25, 2020

Is the side-car eval API part of the proposal? If so, is there any alternative API you have considered?

@tanzhenyu (Contributor) · Nov 25, 2020

I think it's here: "Initially, we aim to have evaluate and predict to only be carried out on the coordinator"

@yuefengz (Member) · Nov 25, 2020

@tanzhenyu you mean this is not part of the proposal? Then we should make it clear here.

@rchao (Author, Member) · Dec 2, 2020

SidecarEvaluator API will be covered in a separate RFC (and is just briefly mentioned in this RFC).



with handle_restartable_error():

@yuefengz (Member) · Nov 25, 2020

Is this API part of the proposal?

@rchao (Author, Member) · Dec 2, 2020

Updated this part. We expect users to use try-except.
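The user-side try-except pattern could be wrapped in a small retry loop like the following (a hypothetical sketch; in practice the restartable exception type would be something like tf.errors.UnavailableError, and training would resume from the latest checkpoint):

```python
def fit_with_retries(fit_fn, max_retries=3, restartable_errors=(Exception,)):
  # Sketch: retry `fit_fn` (e.g. a lambda wrapping model.fit) when a
  # restartable error is raised, up to `max_retries` extra attempts.
  for attempt in range(max_retries + 1):
    try:
      return fit_fn()
    except restartable_errors:
      if attempt == max_retries:
        raise
      # Restore from the latest checkpoint here before retrying.
```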

omalleyt12 and others added 3 commits Nov 25, 2020
Callback changes based on discussion

If we go with Option (1), we should disallow batch-level `Callback`s, since in this case `ParameterServerStrategy` with batch-level `Callback`s will always be slower than training on a single machine.

If we go with Option (2) we should support batch-level `Callback`s, but we will use existing logic in `CallbackList` to detect when batch-level `Callback`s are passed, and only incur the performance penalty of syncing workers each batch when the user has passed batch-level `Callback`s (for context, none of Keras's built-in `Callbacks` other than the `ProgressBar` will incur this penalty). This logic was originally put in place to ensure that TPU async mode was only blocked when needed, and applies equally well to `ParameterServerStrategy` without significant modifications. We will also re-use existing logic to log a warning to the user when their batch-level `Callback`s are causing a significant slowdown in training time. This logic also resides in the `CallbackList` object.
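The detection logic amounts to checking whether any callback overrides a batch-level hook. An illustrative stand-alone sketch (the `Callback` class here is a stand-in, not the actual Keras implementation):

```python
class Callback:
  # Stand-in for tf.keras.callbacks.Callback with two of its hooks.
  def on_train_batch_end(self, batch, logs=None):
    pass

  def on_epoch_end(self, epoch, logs=None):
    pass

def uses_batch_hooks(callbacks):
  # A callback "uses" batch-level hooks iff it overrides them; only then
  # do workers need to be synced (and logs fetched) every batch.
  return any(
      type(cb).on_train_batch_end is not Callback.on_train_batch_end
      for cb in callbacks)
```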

@rchao (Author, Member) · Nov 25, 2020

Question @omalleyt12 - does TPU have ProgressBar by default? It looks like it blocks async when verbose=1, which seems to be the default.

**Option 2: A `step` is one batch on every worker**

With this mental model, every time `Model.train_function` is called, it schedules one batch to execute on each worker. This means that if there are `W` workers, passing `steps_per_epoch=100` will actually run `100 * W` batches of training, with each worker seeing `100` batches.
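The accounting under this option can be stated as a one-liner (hypothetical helper; the point is only the arithmetic):

```python
def total_batches_run(steps_per_epoch, num_workers):
  # Under Option 2, each Model.train_function call schedules one batch on
  # every worker, so `steps_per_epoch` steps run this many batches in total.
  return steps_per_epoch * num_workers

print(total_batches_run(100, 4))  # steps_per_epoch=100 with W=4 -> 400 batches
```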

@yuefengz (Member) · Nov 25, 2020

This option would make the total steps dependent on the number of workers, which is not desirable from users' point of view. It is at odds with the programming model provided by the CTL + PSStrategy. The proposed mental model would also break down if we support elasticity in the future.

In this option, the name of the steps_per_epoch argument should be changed to steps_per_epoch_per_worker.

Alternatively, you could sync every W steps and only run steps_per_epoch/W steps, effectively turning it into synchronous parameter server training.



**Cons:**

- Asynchronous `Callback`s would be limited in what they could do. Any change that an asynchronous `Callback` makes to a `tf.Variable` (such as the `learning_rate`) would not take effect until the next epoch, since all of the batches were already scheduled before the `Callback` executes.

@rchao (Author, Member) · Nov 25, 2020

I think some of the scheduled functions in the same epoch can still pick up the updated variables, because the variable read happens at function execution time, not at scheduling time (unless scheduled functions execute very quickly and the callback's variable updates are very slow).

- This would require a separate code path in `Model.fit`, since the order in which functions are scheduled and `Callback`s are executed would be different in this approach.
- It's not clear how we should handle it when a user passes a mix of synchronous and asynchronous `Callback`s (for instance, if the user passes in one of our existing built-in `Callback`s in addition to an asynchronous `Callback`).

Asynchronous `Callback`s might be worth exploring in a future extension to the functionality of `Model.fit` + `ParameterServerStrategy` integration, but should likely be out-of-scope for the initial design.

@rchao (Author, Member) · Nov 25, 2020

Agreeing with this.

rchao added 8 commits Nov 25, 2020
rchao and others added 5 commits Nov 30, 2020
Batch-level callback updates
@theadactyl changed the title to "RFC: Distributed Keras training APIs with parameter servers" Dec 1, 2020

In TF2 parameter server training, `ClusterCoordinator` naturally supports passing a dataset function to the `create_per_worker_dataset` API, which creates datasets on remote workers. By leveraging this dataset-factory support, `model.fit` with a `dataset_fn` can be implemented by subclassing the existing Keras `DataHandler` (a Keras-internal private API) to provide a worker-distributed dataset for Keras to use (i.e. call `iter` on). Please see the `DataHandler` section below for proposed changes.
*The rationale behind using a `dataset_fn` as opposed to a `dataset` was historical: we could not get sharding to work well with fault tolerance.
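A sketch of the subclass described here (the method and attribute names are assumptions about Keras internals, not the actual implementation):

```python
class ClusterCoordinatorDataHandler(DataHandler):  # DataHandler is Keras-internal
  """Sketch: give Model.fit a worker-distributed dataset built from the
  user's dataset_fn via ClusterCoordinator.create_per_worker_dataset."""

  def _configure_dataset(self, coordinator, dataset_fn):
    # `coordinator` is assumed to be the ClusterCoordinator created in
    # Model.compile; the returned object supports iter(), as fit expects.
    self._dataset = coordinator.create_per_worker_dataset(dataset_fn)

  def steps(self):
    # Epoch bounds come solely from steps_per_epoch, since the per-worker
    # datasets are infinite.
    ...
```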

@byronyi (Contributor) · Dec 2, 2020

Is it still a hard requirement if the data comes from tf.data service processing in distributed_epoch mode?

@rchao (Author, Member) · Dec 2, 2020

I think having tf.data service support would be ideal; workers would then just fetch the next example without having to worry about OutOfRangeError. There would be no need for sharding either.
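With tf.data service in distributed_epoch mode, the dataset_fn might look like the following sketch (the input files and dispatcher address are placeholders/assumptions):

```python
import tensorflow as tf

def dataset_fn():
  ds = tf.data.Dataset.list_files('/data/train-*')  # hypothetical input files
  ds = ds.interleave(tf.data.TFRecordDataset)
  ds = ds.batch(32)
  # Each consumer pulls the next available elements from a shared job, so
  # visitation is coordinated centrally and no per-worker sharding (or
  # infinite repeat) is required.
  return ds.apply(tf.data.experimental.service.distribute(
      processing_mode='distributed_epoch',
      service='grpc://dispatcher:5000'))  # address is an assumption
```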

rchao added 15 commits Dec 4, 2020